Skip to content

Commit ab7851e

Browse files
committed
refine input dtype check
1 parent 599f223 commit ab7851e

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

tests/run_pretrained_models.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class Test(object):
8282
target = []
8383

8484
def __init__(self, url, local, make_input, input_names, output_names,
85-
disabled=False, more_inputs=None, rtol=0.01, atol=1e-6,
85+
disabled=False, rtol=0.01, atol=1e-6,
8686
check_only_shape=False, model_type="frozen", force_input_shape=False,
8787
skip_tensorflow=False, opset_constraints=None):
8888
self.url = url
@@ -91,7 +91,6 @@ def __init__(self, url, local, make_input, input_names, output_names,
9191
self.input_names = input_names
9292
self.output_names = output_names
9393
self.disabled = disabled
94-
self.more_inputs = more_inputs
9594
self.rtol = rtol
9695
self.atol = atol
9796
self.check_only_shape = check_only_shape
@@ -212,34 +211,30 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
212211
else:
213212
graph_def, input_names, outputs = loader.from_graphdef(model_path, input_names, outputs)
214213

215-
# create the input data
216-
inputs = {}
217-
for k, v in self.input_names.items():
218-
if k not in input_names:
219-
continue
220-
if isinstance(v, six.text_type) and v.startswith("np."):
221-
inputs[k] = eval(v) # pylint: disable=eval-used
222-
else:
223-
inputs[k] = self.make_input(v)
224-
if self.more_inputs:
225-
for k, v in self.more_inputs.items():
226-
inputs[k] = v
227-
228-
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
214+
# remove unused input names
215+
input_names = list(set(input_names).intersection(self.input_names.keys()))
216+
graph_def = tf2onnx.tfonnx.tf_optimize(input_names, self.output_names, graph_def, fold_const)
229217
if utils.is_debug_mode():
230218
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
219+
220+
inputs = {}
231221
shape_override = {}
232222
g = tf.import_graph_def(graph_def, name='')
233223
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
234-
235-
# fix inputs if needed
236-
for k in inputs.keys(): # pylint: disable=consider-iterating-dictionary
224+
# create the input data
225+
for k in input_names:
226+
v = self.input_names[k]
237227
t = sess.graph.get_tensor_by_name(k)
238-
dtype = tf.as_dtype(t.dtype).name
239-
v = inputs[k]
240-
if dtype != v.dtype:
241-
logger.warning("input dtype doesn't match tensorflow's")
242-
inputs[k] = np.array(v, dtype=dtype)
228+
expected_dtype = tf.as_dtype(t.dtype).name
229+
if isinstance(v, six.text_type) and v.startswith("np."):
230+
np_value = eval(v) # pylint: disable=eval-used
231+
if expected_dtype != np_value.dtype:
232+
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
233+
np_value.dtype)
234+
inputs[k] = np_value.astype(expected_dtype)
235+
else:
236+
inputs[k] = self.make_input(v).astype(expected_dtype)
237+
243238
if self.force_input_shape:
244239
for k, v in inputs.items():
245240
shape_override[k] = list(v.shape)
@@ -405,7 +400,7 @@ def load_tests_from_yaml(path):
405400
opset_constraints.append(c)
406401

407402
kwargs = {}
408-
for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type",
403+
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
409404
"skip_tensorflow", "force_input_shape"]:
410405
if settings.get(kw) is not None:
411406
kwargs[kw] = settings[kw]

0 commit comments

Comments
 (0)