Skip to content

Commit f2f540b

Browse files
authored
Merge pull request #502 from nbcsm/test
enhance pre-trained model test
2 parents c30c692 + ab7851e commit f2f540b

File tree

6 files changed

+145
-93
lines changed

6 files changed

+145
-93
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,8 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
5555
results = prepared_backend.run(inputs)
5656
return results
5757

58-
def run_onnxmsrtnext(self, model_path, inputs, output_names):
59-
"""Run test against msrt-next backend."""
60-
import lotus
61-
m = lotus.InferenceSession(model_path)
62-
results = m.run(output_names, inputs)
63-
return results
64-
6558
def run_onnxruntime(self, model_path, inputs, output_names):
66-
"""Run test against msrt-next backend."""
59+
"""Run test against onnxruntime backend."""
6760
import onnxruntime as rt
6861
m = rt.InferenceSession(model_path)
6962
results = m.run(output_names, inputs)
@@ -73,9 +66,7 @@ def run_backend(self, g, outputs, input_dict):
7366
model_proto = g.make_model("test")
7467
model_path = self.save_onnx_model(model_proto, input_dict)
7568

76-
if self.config.backend == "onnxmsrtnext":
77-
y = self.run_onnxmsrtnext(model_path, input_dict, outputs)
78-
elif self.config.backend == "onnxruntime":
69+
if self.config.backend == "onnxruntime":
7970
y = self.run_onnxruntime(model_path, input_dict, outputs)
8071
elif self.config.backend == "caffe2":
8172
y = self.run_onnxcaffe2(model_proto, input_dict)

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def load():
102102
if "pytest" not in sys.argv[0]:
103103
parser = argparse.ArgumentParser()
104104
parser.add_argument("--backend", default=config.backend,
105-
choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
105+
choices=["caffe2", "onnxruntime"],
106106
help="backend to test against")
107107
parser.add_argument("--opset", type=int, default=config.opset, help="opset to test against")
108108
parser.add_argument("--target", default=",".join(config.target), choices=constants.POSSIBLE_TARGETS,

tests/run_pretrained_models.py

Lines changed: 126 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
import argparse
1212
import os
13+
import re
1314
import sys
1415
import tarfile
1516
import time
1617
import zipfile
18+
from collections import namedtuple
1719

1820
import PIL.Image
1921
import numpy as np
@@ -70,6 +72,8 @@ def get_ramp(shape):
7072
"get_ramp": get_ramp
7173
}
7274

75+
OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
76+
7377

7478
class Test(object):
7579
"""Main Test class."""
@@ -78,16 +82,15 @@ class Test(object):
7882
target = []
7983

8084
def __init__(self, url, local, make_input, input_names, output_names,
81-
disabled=False, more_inputs=None, rtol=0.01, atol=1e-6,
85+
disabled=False, rtol=0.01, atol=1e-6,
8286
check_only_shape=False, model_type="frozen", force_input_shape=False,
83-
skip_tensorflow=False):
87+
skip_tensorflow=False, opset_constraints=None):
8488
self.url = url
8589
self.make_input = make_input
8690
self.local = local
8791
self.input_names = input_names
8892
self.output_names = output_names
8993
self.disabled = disabled
90-
self.more_inputs = more_inputs
9194
self.rtol = rtol
9295
self.atol = atol
9396
self.check_only_shape = check_only_shape
@@ -97,9 +100,10 @@ def __init__(self, url, local, make_input, input_names, output_names,
97100
self.model_type = model_type
98101
self.force_input_shape = force_input_shape
99102
self.skip_tensorflow = skip_tensorflow
103+
self.opset_constraints = opset_constraints
100104

101-
def download_file(self):
102-
"""Download file from url."""
105+
def download_model(self):
106+
"""Download model from url."""
103107
cache_dir = Test.cache_dir
104108
if not os.path.exists(cache_dir):
105109
os.makedirs(cache_dir)
@@ -163,21 +167,8 @@ def run_caffe2(self, name, model_proto, inputs):
163167
self.onnx_runtime = time.time() - start
164168
return results
165169

166-
def run_onnxmsrtnext(self, name, model_proto, inputs):
167-
"""Run test against msrt-next backend."""
168-
import lotus
169-
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto)
170-
m = lotus.InferenceSession(model_path)
171-
results = m.run(self.output_names, inputs)
172-
if self.perf:
173-
start = time.time()
174-
for _ in range(PERFITER):
175-
_ = m.run(self.output_names, inputs)
176-
self.onnx_runtime = time.time() - start
177-
return results
178-
179170
def run_onnxruntime(self, name, model_proto, inputs):
180-
"""Run test against msrt-next backend."""
171+
"""Run test against onnxruntime backend."""
181172
import onnxruntime as rt
182173
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True)
183174
logger.info("Model saved to %s", model_path)
@@ -200,19 +191,17 @@ def create_onnx_file(name, model_proto, inputs, outdir):
200191
def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
201192
perf=None, fold_const=None):
202193
"""Run complete test against backend."""
203-
logger.info("===================================")
204-
logger.info("Running %s", name)
205194
self.perf = perf
206195

207196
# get the model
208197
if self.url:
209-
_, dir_name = self.download_file()
198+
_, dir_name = self.download_model()
199+
logger.info("Downloaded to %s", dir_name)
210200
model_path = os.path.join(dir_name, self.local)
211201
else:
212202
model_path = self.local
213-
dir_name = os.path.dirname(self.local)
214-
logger.info("Downloaded to %s", model_path)
215203

204+
logger.info("Load model from %s", model_path)
216205
input_names = list(self.input_names.keys())
217206
outputs = self.output_names
218207
if self.model_type in ["checkpoint"]:
@@ -222,34 +211,30 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
222211
else:
223212
graph_def, input_names, outputs = loader.from_graphdef(model_path, input_names, outputs)
224213

225-
# create the input data
226-
inputs = {}
227-
for k, v in self.input_names.items():
228-
if k not in input_names:
229-
continue
230-
if isinstance(v, six.text_type) and v.startswith("np."):
231-
inputs[k] = eval(v) # pylint: disable=eval-used
232-
else:
233-
inputs[k] = self.make_input(v)
234-
if self.more_inputs:
235-
for k, v in self.more_inputs.items():
236-
inputs[k] = v
237-
238-
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
214+
# remove unused input names
215+
input_names = list(set(input_names).intersection(self.input_names.keys()))
216+
graph_def = tf2onnx.tfonnx.tf_optimize(input_names, self.output_names, graph_def, fold_const)
239217
if utils.is_debug_mode():
240218
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
219+
220+
inputs = {}
241221
shape_override = {}
242222
g = tf.import_graph_def(graph_def, name='')
243223
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
244-
245-
# fix inputs if needed
246-
for k in inputs.keys(): # pylint: disable=consider-iterating-dictionary
224+
# create the input data
225+
for k in input_names:
226+
v = self.input_names[k]
247227
t = sess.graph.get_tensor_by_name(k)
248-
dtype = tf.as_dtype(t.dtype).name
249-
v = inputs[k]
250-
if dtype != v.dtype:
251-
logger.warning("input dtype doesn't match tensorflow's")
252-
inputs[k] = np.array(v, dtype=dtype)
228+
expected_dtype = tf.as_dtype(t.dtype).name
229+
if isinstance(v, six.text_type) and v.startswith("np."):
230+
np_value = eval(v) # pylint: disable=eval-used
231+
if expected_dtype != np_value.dtype:
232+
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
233+
np_value.dtype)
234+
inputs[k] = np_value.astype(expected_dtype)
235+
else:
236+
inputs[k] = self.make_input(v).astype(expected_dtype)
237+
253238
if self.force_input_shape:
254239
for k, v in inputs.items():
255240
shape_override[k] = list(v.shape)
@@ -279,8 +264,6 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
279264
onnx_results = None
280265
if backend == "caffe2":
281266
onnx_results = self.run_caffe2(name, model_proto, inputs)
282-
elif backend == "onnxmsrtnext":
283-
onnx_results = self.run_onnxmsrtnext(name, model_proto, inputs)
284267
elif backend == "onnxruntime":
285268
onnx_results = self.run_onnxruntime(name, model_proto, inputs)
286269
else:
@@ -307,6 +290,41 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
307290

308291
return False
309292

293+
def check_opset_constraints(self, opset, extra_opset=None):
294+
""" Return (condition, reason) tuple, condition is True if constraints are met. """
295+
if not self.opset_constraints:
296+
return True, None
297+
298+
opsets = {"onnx": opset}
299+
if extra_opset:
300+
for e in extra_opset:
301+
opsets[e.domain] = e.version
302+
303+
for constraint in self.opset_constraints:
304+
domain = constraint.domain
305+
opset_version = opsets.get(domain)
306+
if not opset_version:
307+
return False, "conversion requires opset {}".format(domain)
308+
309+
if constraint.min_version and opset_version < constraint.min_version:
310+
reason = "conversion requires opset {} >= {}".format(domain, constraint.min_version)
311+
return False, reason
312+
313+
if constraint.max_version and opset_version > constraint.max_version:
314+
reason = "conversion requires opset {} <= {}".format(domain, constraint.max_version)
315+
return False, reason
316+
317+
if constraint.excluded_version:
318+
if utils.is_list_or_tuple(constraint.excluded_version):
319+
skip = opset_version in constraint.excluded_version
320+
else:
321+
skip = opset_version == constraint.excluded_version
322+
if skip:
323+
reason = "conversion requires opset {} != {}".format(domain, constraint.excluded_version)
324+
return False, reason
325+
326+
return True, None
327+
310328

311329
def get_args():
312330
"""Parse commandline."""
@@ -316,7 +334,7 @@ def get_args():
316334
parser.add_argument("--tests", help="tests to run")
317335
parser.add_argument("--target", default="", help="target platform")
318336
parser.add_argument("--backend", default="onnxruntime",
319-
choices=["caffe2", "onnxmsrtnext", "onnxruntime"], help="backend to use")
337+
choices=["caffe2", "onnxruntime"], help="backend to use")
320338
parser.add_argument("--opset", type=int, default=None, help="opset to use")
321339
parser.add_argument("--extra_opset", default=None,
322340
help="extra opset with format like domain:version, e.g. com.microsoft:1")
@@ -339,21 +357,57 @@ def get_args():
339357
return args
340358

341359

342-
def tests_from_yaml(fname):
360+
def load_tests_from_yaml(path):
343361
"""Create test class from yaml file."""
362+
path = os.path.abspath(path)
363+
base_dir = os.path.dirname(path)
364+
344365
tests = {}
345-
config = yaml.load(open(fname, 'r').read())
346-
for k, v in config.items():
347-
input_func = v.get("input_get")
366+
config = yaml.safe_load(open(path, 'r').read())
367+
for name, settings in config.items():
368+
if name in tests:
369+
raise ValueError("Found duplicated test: {}".format(name))
370+
371+
# parse model and url, non-absolute local path is relative to yaml directory
372+
model = settings.get("model")
373+
url = settings.get("url")
374+
if not url and not os.path.isabs(model):
375+
model = os.path.join(base_dir, model)
376+
377+
# parse input_get
378+
input_func = settings.get("input_get")
348379
input_func = _INPUT_FUNC_MAPPING[input_func]
380+
381+
# parse inputs, non-absolute npy file path for np.load is relative to yaml directory
382+
inputs = settings.get("inputs")
383+
for k, v in list(inputs.items()):
384+
if isinstance(v, str):
385+
# assume at most 1 match
386+
matches = re.findall(r"np\.load\((r?['\"].*?['\"])", v)
387+
if matches:
388+
npy_path = matches[0].lstrip('r').strip("'").strip('"')
389+
if not os.path.isabs(npy_path):
390+
abs_npy_path = os.path.join(base_dir, npy_path)
391+
inputs[k] = v.replace(matches[0], "r'{}'".format(abs_npy_path))
392+
393+
# parse opset_constraints
394+
opset_constraints = []
395+
section = settings.get("opset_constraints")
396+
if section:
397+
for k, v in section.items():
398+
c = OpsetConstraint(k, min_version=v.get("min"), max_version=v.get("max"),
399+
excluded_version=v.get("excluded"))
400+
opset_constraints.append(c)
401+
349402
kwargs = {}
350-
for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type",
403+
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
351404
"skip_tensorflow", "force_input_shape"]:
352-
if v.get(kw) is not None:
353-
kwargs[kw] = v[kw]
405+
if settings.get(kw) is not None:
406+
kwargs[kw] = settings[kw]
354407

355-
test = Test(v.get("url"), v.get("model"), input_func, v.get("inputs"), v.get("outputs"), **kwargs)
356-
tests[k] = test
408+
test = Test(url, model, input_func, inputs, settings.get("outputs"),
409+
opset_constraints=opset_constraints, **kwargs)
410+
tests[name] = test
357411
return tests
358412

359413

@@ -365,7 +419,7 @@ def main():
365419

366420
Test.cache_dir = args.cache
367421
Test.target = args.target
368-
tests = tests_from_yaml(args.config)
422+
tests = load_tests_from_yaml(args.config)
369423
if args.list:
370424
logger.info(sorted(tests.keys()))
371425
return 0
@@ -377,11 +431,22 @@ def main():
377431
failed = 0
378432
count = 0
379433
for test in test_keys:
434+
logger.info("===================================")
435+
380436
t = tests[test]
381-
if args.tests is None and t.disabled and not args.include_disabled:
382-
continue
437+
if args.tests is None:
438+
if t.disabled and not args.include_disabled:
439+
logger.info("Skip %s: disabled", test)
440+
continue
441+
442+
condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)
443+
if not condition:
444+
logger.info("Skip %s: %s", test, reason)
445+
continue
446+
383447
count += 1
384448
try:
449+
logger.info("Running %s", test)
385450
ret = t.run_test(test, backend=args.backend, onnx_file=args.onnx_file,
386451
opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
387452
fold_const=args.fold_const)

0 commit comments

Comments
 (0)