@@ -82,7 +82,7 @@ class Test(object):
     target = []
 
     def __init__(self, url, local, make_input, input_names, output_names,
-                 disabled=False, more_inputs=None, rtol=0.01, atol=1e-6,
+                 disabled=False, rtol=0.01, atol=1e-6,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
                  skip_tensorflow=False, opset_constraints=None):
         self.url = url
@@ -91,7 +91,6 @@ def __init__(self, url, local, make_input, input_names, output_names,
         self.input_names = input_names
         self.output_names = output_names
         self.disabled = disabled
-        self.more_inputs = more_inputs
         self.rtol = rtol
         self.atol = atol
         self.check_only_shape = check_only_shape
@@ -212,34 +211,30 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         else:
             graph_def, input_names, outputs = loader.from_graphdef(model_path, input_names, outputs)
 
-        # create the input data
-        inputs = {}
-        for k, v in self.input_names.items():
-            if k not in input_names:
-                continue
-            if isinstance(v, six.text_type) and v.startswith("np."):
-                inputs[k] = eval(v)  # pylint: disable=eval-used
-            else:
-                inputs[k] = self.make_input(v)
-        if self.more_inputs:
-            for k, v in self.more_inputs.items():
-                inputs[k] = v
-
-        graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
+        # remove unused input names
+        input_names = list(set(input_names).intersection(self.input_names.keys()))
+        graph_def = tf2onnx.tfonnx.tf_optimize(input_names, self.output_names, graph_def, fold_const)
         if utils.is_debug_mode():
             utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
+
+        inputs = {}
         shape_override = {}
         g = tf.import_graph_def(graph_def, name='')
         with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
-
-            # fix inputs if needed
-            for k in inputs.keys():  # pylint: disable=consider-iterating-dictionary
+            # create the input data
+            for k in input_names:
+                v = self.input_names[k]
                 t = sess.graph.get_tensor_by_name(k)
-                dtype = tf.as_dtype(t.dtype).name
-                v = inputs[k]
-                if dtype != v.dtype:
-                    logger.warning("input dtype doesn't match tensorflow's")
-                    inputs[k] = np.array(v, dtype=dtype)
+                expected_dtype = tf.as_dtype(t.dtype).name
+                if isinstance(v, six.text_type) and v.startswith("np."):
+                    np_value = eval(v)  # pylint: disable=eval-used
+                    if expected_dtype != np_value.dtype:
+                        logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
+                                       np_value.dtype)
+                    inputs[k] = np_value.astype(expected_dtype)
+                else:
+                    inputs[k] = self.make_input(v).astype(expected_dtype)
+
             if self.force_input_shape:
                 for k, v in inputs.items():
                     shape_override[k] = list(v.shape)
@@ -405,7 +400,7 @@ def load_tests_from_yaml(path):
                 opset_constraints.append(c)
 
         kwargs = {}
-        for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type",
+        for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
                    "skip_tensorflow", "force_input_shape"]:
            if settings.get(kw) is not None:
                kwargs[kw] = settings[kw]
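With `more_inputs` removed, every feed tensor has to be listed in `input_names`; a value that `make_input` cannot generate from a shape can instead be written as a string starting with `np.`, which the runner `eval`s and casts to the graph's dtype. Below is a minimal sketch of constructing a test after this change, assuming the `Test` class above is in scope; the URL, file name, tensor names, shapes, and the `make_input` helper are placeholders for illustration, not values from the repository.

```python
import numpy as np

# Hypothetical test definition; only the constructor signature mirrors the diff above.
test = Test(
    url="http://example.com/model.tar.gz",   # placeholder download location
    local="model.pb",                         # placeholder frozen-graph file inside the archive
    make_input=lambda shape: np.random.uniform(size=shape).astype(np.float32),  # placeholder generator
    input_names={
        "input:0": [1, 224, 224, 3],                     # a shape, handed to make_input
        "seq_len:0": "np.array([224], dtype=np.int32)",  # a literal value that previously went through more_inputs
    },
    output_names=["output:0"],
    rtol=0.01,
    atol=1e-6,
)
```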