
Commit 599f223

opset_constraints
1 parent a66c448 commit 599f223

1 file changed (+93, -32 lines)


tests/run_pretrained_models.py

Lines changed: 93 additions & 32 deletions
@@ -15,6 +15,7 @@
 import tarfile
 import time
 import zipfile
+from collections import namedtuple

 import PIL.Image
 import numpy as np
@@ -71,6 +72,8 @@ def get_ramp(shape):
     "get_ramp": get_ramp
 }

+OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
+

 class Test(object):
     """Main Test class."""
@@ -81,7 +84,7 @@ class Test(object):
     def __init__(self, url, local, make_input, input_names, output_names,
                  disabled=False, more_inputs=None, rtol=0.01, atol=1e-6,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
-                 skip_tensorflow=False):
+                 skip_tensorflow=False, opset_constraints=None):
         self.url = url
         self.make_input = make_input
         self.local = local
@@ -98,6 +101,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
         self.model_type = model_type
         self.force_input_shape = force_input_shape
         self.skip_tensorflow = skip_tensorflow
+        self.opset_constraints = opset_constraints

     def download_model(self):
         """Download model from url."""
@@ -188,8 +192,6 @@ def create_onnx_file(name, model_proto, inputs, outdir):
     def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
                  perf=None, fold_const=None):
         """Run complete test against backend."""
-        logger.info("===================================")
-        logger.info("Running %s", name)
         self.perf = perf

         # get the model
@@ -293,6 +295,41 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,

         return False

+    def check_opset_constraints(self, opset, extra_opset=None):
+        """ Return (condition, reason) tuple, condition is True if constraints are met. """
+        if not self.opset_constraints:
+            return True, None
+
+        opsets = {"onnx": opset}
+        if extra_opset:
+            for e in extra_opset:
+                opsets[e.domain] = e.version
+
+        for constraint in self.opset_constraints:
+            domain = constraint.domain
+            opset_version = opsets.get(domain)
+            if not opset_version:
+                return False, "conversion requires opset {}".format(domain)
+
+            if constraint.min_version and opset_version < constraint.min_version:
+                reason = "conversion requires opset {} >= {}".format(domain, constraint.min_version)
+                return False, reason
+
+            if constraint.max_version and opset_version > constraint.max_version:
+                reason = "conversion requires opset {} <= {}".format(domain, constraint.max_version)
+                return False, reason
+
+            if constraint.excluded_version:
+                if utils.is_list_or_tuple(constraint.excluded_version):
+                    skip = opset_version in constraint.excluded_version
+                else:
+                    skip = opset_version == constraint.excluded_version
+                if skip:
+                    reason = "conversion requires opset {} != {}".format(domain, constraint.excluded_version)
+                    return False, reason
+
+        return True, None
+

 def get_args():
     """Parse commandline."""
@@ -325,44 +362,57 @@ def get_args():
     return args


-def tests_from_yaml(path):
+def load_tests_from_yaml(path):
     """Create test class from yaml file."""
     path = os.path.abspath(path)
     base_dir = os.path.dirname(path)

     tests = {}
-    config = yaml.load(open(path, 'r').read())
-    for k, v in config.items():
-        input_func = v.get("input_get")
+    config = yaml.safe_load(open(path, 'r').read())
+    for name, settings in config.items():
+        if name in tests:
+            raise ValueError("Found duplicated test: {}".format(name))
+
+        # parse model and url, non-absolute local path is relative to yaml directory
+        model = settings.get("model")
+        url = settings.get("url")
+        if not url and not os.path.isabs(model):
+            model = os.path.join(base_dir, model)
+
+        # parse input_get
+        input_func = settings.get("input_get")
         input_func = _INPUT_FUNC_MAPPING[input_func]
-        kwargs = {}
-        for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type",
-                   "skip_tensorflow", "force_input_shape"]:
-            if v.get(kw) is not None:
-                kwargs[kw] = v[kw]
-
-        # when model is local, non-absolute path is relative to yaml directory
-        url = v.get("url")
-        model = v.get("model")
-        if not url:
-            if not os.path.isabs(model):
-                model = os.path.join(base_dir, model)
-
-        # non-absolute npy file path for np.load is relative to yaml directory
-        input_names = v.get("inputs")
-        for key in list(input_names.keys()):
-            value = input_names[key]
-            if isinstance(value, str):
+
+        # parse inputs, non-absolute npy file path for np.load is relative to yaml directory
+        inputs = settings.get("inputs")
+        for k, v in list(inputs.items()):
+            if isinstance(v, str):
                 # assume at most 1 match
-                matches = re.findall(r"np\.load\((r?['\"].*?['\"])", value)
+                matches = re.findall(r"np\.load\((r?['\"].*?['\"])", v)
                 if matches:
                     npy_path = matches[0].lstrip('r').strip("'").strip('"')
                     if not os.path.isabs(npy_path):
                         abs_npy_path = os.path.join(base_dir, npy_path)
-                        input_names[key] = value.replace(matches[0], "r'{}'".format(abs_npy_path))
+                        inputs[k] = v.replace(matches[0], "r'{}'".format(abs_npy_path))
+
+        # parse opset_constraints
+        opset_constraints = []
+        section = settings.get("opset_constraints")
+        if section:
+            for k, v in section.items():
+                c = OpsetConstraint(k, min_version=v.get("min"), max_version=v.get("max"),
+                                    excluded_version=v.get("excluded"))
+                opset_constraints.append(c)

-        test = Test(url, model, input_func, input_names, v.get("outputs"), **kwargs)
-        tests[k] = test
+        kwargs = {}
+        for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type",
+                   "skip_tensorflow", "force_input_shape"]:
+            if settings.get(kw) is not None:
+                kwargs[kw] = settings[kw]
+
+        test = Test(url, model, input_func, inputs, settings.get("outputs"),
+                    opset_constraints=opset_constraints, **kwargs)
+        tests[name] = test
     return tests


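Note (not part of the diff): a hedged sketch of the yaml shape the loader now accepts. The test name, model path and tensor names below are made up; only the keys mirror what load_tests_from_yaml reads (model/url, input_get, inputs, outputs, and the new per-domain opset_constraints section with optional min/max/excluded).

import yaml

example = yaml.safe_load("""
my_model:                        # made-up test name
  model: models/my_model.pb      # made-up path, resolved relative to the yaml file
  input_get: get_ramp
  inputs:
    "input:0": [1, 224, 224, 3]
  outputs:
    - "output:0"
  opset_constraints:
    onnx:
      min: 7
      max: 10
      excluded: [8, 9]
""")

# Same parsing as the loader above: one OpsetConstraint per domain entry.
for domain, v in example["my_model"]["opset_constraints"].items():
    print(OpsetConstraint(domain, min_version=v.get("min"), max_version=v.get("max"),
                          excluded_version=v.get("excluded")))
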
@@ -374,7 +424,7 @@ def main():

     Test.cache_dir = args.cache
     Test.target = args.target
-    tests = tests_from_yaml(args.config)
+    tests = load_tests_from_yaml(args.config)
     if args.list:
         logger.info(sorted(tests.keys()))
         return 0
@@ -386,11 +436,22 @@ def main():
     failed = 0
     count = 0
     for test in test_keys:
+        logger.info("===================================")
+
         t = tests[test]
-        if args.tests is None and t.disabled and not args.include_disabled:
-            continue
+        if args.tests is None:
+            if t.disabled and not args.include_disabled:
+                logger.info("Skip %s: disabled", test)
+                continue
+
+            condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)
+            if not condition:
+                logger.info("Skip %s: %s", test, reason)
+                continue
+
         count += 1
         try:
+            logger.info("Running %s", test)
             ret = t.run_test(test, backend=args.backend, onnx_file=args.onnx_file,
                              opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
                              fold_const=args.fold_const)
