1
1
"""
2
2
Copyright (c) 2018-2020 Intel Corporation
3
-
4
3
Licensed under the Apache License, Version 2.0 (the "License");
5
4
you may not use this file except in compliance with the License.
6
5
You may obtain a copy of the License at
7
-
8
6
http://www.apache.org/licenses/LICENSE-2.0
9
-
10
7
Unless required by applicable law or agreed to in writing, software
11
8
distributed under the License is distributed on an "AS IS" BASIS,
12
9
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
16
from pathlib import Path
20
17
21
18
import numpy as np
22
- import mxnet
23
19
24
20
from .launcher import Launcher , LauncherConfigValidator , ListInputsField
25
21
from ..config import PathField , StringField , NumberField , ConfigError
@@ -58,6 +54,13 @@ def parameters(cls):
58
54
return parameters
59
55
60
56
def __init__ (self , config_entry : dict , * args , ** kwargs ):
57
+ try :
58
+ import mxnet # pylint: disable=C0415
59
+ self .mxnet = mxnet
60
+ except ImportError as import_error :
61
+ raise ValueError (
62
+ "MXNet isn't installed. Please, install it before using. \n {}" .format (import_error .msg )
63
+ )
61
64
super ().__init__ (config_entry , * args , ** kwargs )
62
65
self ._delayed_model_loading = kwargs .get ('delayed_model_loading' , False )
63
66
@@ -78,9 +81,9 @@ def __init__(self, config_entry: dict, *args, **kwargs):
78
81
identifier = match .group ('identifier' )
79
82
if identifier is None :
80
83
identifier = 0
81
- device_context = mxnet .gpu (int (identifier ))
84
+ device_context = self . mxnet .gpu (int (identifier ))
82
85
else :
83
- device_context = mxnet .cpu ()
86
+ device_context = self . mxnet .cpu ()
84
87
85
88
# Get batch from config or 1
86
89
self ._batch = self .config .get ('batch' , 1 )
@@ -109,7 +112,7 @@ def batch(self):
109
112
110
113
def fit_to_input (self , data , input_layer , layout , precision ):
111
114
data = np .transpose (data , layout )
112
- return mxnet .nd .array (data .astype (precision ) if precision else data )
115
+ return self . mxnet .nd .array (data .astype (precision ) if precision else data )
113
116
114
117
@property
115
118
def inputs (self ):
@@ -125,9 +128,9 @@ def predict(self, inputs, metadata=None, **kwargs):
125
128
"""
126
129
results = []
127
130
for infer_input in inputs :
128
- data_iter = mxnet .io .NDArrayIter (
131
+ data_iter = self . mxnet .io .NDArrayIter (
129
132
data = infer_input , label = None , batch_size = self .batch )
130
- data_batch = mxnet .io .DataBatch (data = data_iter .data_list )
133
+ data_batch = self . mxnet .io .DataBatch (data = data_iter .data_list )
131
134
132
135
# Infer
133
136
self .module .forward (data_batch )
0 commit comments