15
15
import tarfile
16
16
import time
17
17
import zipfile
18
+ from collections import namedtuple
18
19
19
20
import PIL .Image
20
21
import numpy as np
@@ -71,6 +72,8 @@ def get_ramp(shape):
71
72
"get_ramp" : get_ramp
72
73
}
73
74
75
+ OpsetConstraint = namedtuple ("OpsetConstraint" , "domain, min_version, max_version, excluded_version" )
76
+
74
77
75
78
class Test (object ):
76
79
"""Main Test class."""
@@ -81,7 +84,7 @@ class Test(object):
81
84
def __init__ (self , url , local , make_input , input_names , output_names ,
82
85
disabled = False , more_inputs = None , rtol = 0.01 , atol = 1e-6 ,
83
86
check_only_shape = False , model_type = "frozen" , force_input_shape = False ,
84
- skip_tensorflow = False ):
87
+ skip_tensorflow = False , opset_constraints = None ):
85
88
self .url = url
86
89
self .make_input = make_input
87
90
self .local = local
@@ -98,6 +101,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
98
101
self .model_type = model_type
99
102
self .force_input_shape = force_input_shape
100
103
self .skip_tensorflow = skip_tensorflow
104
+ self .opset_constraints = opset_constraints
101
105
102
106
def download_model (self ):
103
107
"""Download model from url."""
@@ -188,8 +192,6 @@ def create_onnx_file(name, model_proto, inputs, outdir):
188
192
def run_test (self , name , backend = "caffe2" , onnx_file = None , opset = None , extra_opset = None ,
189
193
perf = None , fold_const = None ):
190
194
"""Run complete test against backend."""
191
- logger .info ("===================================" )
192
- logger .info ("Running %s" , name )
193
195
self .perf = perf
194
196
195
197
# get the model
@@ -293,6 +295,41 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
293
295
294
296
return False
295
297
298
+ def check_opset_constraints (self , opset , extra_opset = None ):
299
+ """ Return (condition, reason) tuple, condition is True if constraints are met. """
300
+ if not self .opset_constraints :
301
+ return True , None
302
+
303
+ opsets = {"onnx" : opset }
304
+ if extra_opset :
305
+ for e in extra_opset :
306
+ opsets [e .domain ] = e .version
307
+
308
+ for constraint in self .opset_constraints :
309
+ domain = constraint .domain
310
+ opset_version = opsets .get (domain )
311
+ if not opset_version :
312
+ return False , "conversion requires opset {}" .format (domain )
313
+
314
+ if constraint .min_version and opset_version < constraint .min_version :
315
+ reason = "conversion requires opset {} >= {}" .format (domain , constraint .min_version )
316
+ return False , reason
317
+
318
+ if constraint .max_version and opset_version > constraint .max_version :
319
+ reason = "conversion requires opset {} <= {}" .format (domain , constraint .max_version )
320
+ return False , reason
321
+
322
+ if constraint .excluded_version :
323
+ if utils .is_list_or_tuple (constraint .excluded_version ):
324
+ skip = opset_version in constraint .excluded_version
325
+ else :
326
+ skip = opset_version == constraint .excluded_version
327
+ if skip :
328
+ reason = "conversion requires opset {} != {}" .format (domain , constraint .excluded_version )
329
+ return False , reason
330
+
331
+ return True , None
332
+
296
333
297
334
def get_args ():
298
335
"""Parse commandline."""
@@ -325,44 +362,57 @@ def get_args():
325
362
return args
326
363
327
364
328
- def tests_from_yaml (path ):
365
+ def load_tests_from_yaml (path ):
329
366
"""Create test class from yaml file."""
330
367
path = os .path .abspath (path )
331
368
base_dir = os .path .dirname (path )
332
369
333
370
tests = {}
334
- config = yaml .load (open (path , 'r' ).read ())
335
- for k , v in config .items ():
336
- input_func = v .get ("input_get" )
371
+ config = yaml .safe_load (open (path , 'r' ).read ())
372
+ for name , settings in config .items ():
373
+ if name in tests :
374
+ raise ValueError ("Found duplicated test: {}" .format (name ))
375
+
376
+ # parse model and url, non-absolute local path is relative to yaml directory
377
+ model = settings .get ("model" )
378
+ url = settings .get ("url" )
379
+ if not url and not os .path .isabs (model ):
380
+ model = os .path .join (base_dir , model )
381
+
382
+ # parse input_get
383
+ input_func = settings .get ("input_get" )
337
384
input_func = _INPUT_FUNC_MAPPING [input_func ]
338
- kwargs = {}
339
- for kw in ["rtol" , "atol" , "disabled" , "more_inputs" , "check_only_shape" , "model_type" ,
340
- "skip_tensorflow" , "force_input_shape" ]:
341
- if v .get (kw ) is not None :
342
- kwargs [kw ] = v [kw ]
343
-
344
- # when model is local, non-absolute path is relative to yaml directory
345
- url = v .get ("url" )
346
- model = v .get ("model" )
347
- if not url :
348
- if not os .path .isabs (model ):
349
- model = os .path .join (base_dir , model )
350
-
351
- # non-absolute npy file path for np.load is relative to yaml directory
352
- input_names = v .get ("inputs" )
353
- for key in list (input_names .keys ()):
354
- value = input_names [key ]
355
- if isinstance (value , str ):
385
+
386
+ # parse inputs, non-absolute npy file path for np.load is relative to yaml directory
387
+ inputs = settings .get ("inputs" )
388
+ for k , v in list (inputs .items ()):
389
+ if isinstance (v , str ):
356
390
# assume at most 1 match
357
- matches = re .findall (r"np\.load\((r?['\"].*?['\"])" , value )
391
+ matches = re .findall (r"np\.load\((r?['\"].*?['\"])" , v )
358
392
if matches :
359
393
npy_path = matches [0 ].lstrip ('r' ).strip ("'" ).strip ('"' )
360
394
if not os .path .isabs (npy_path ):
361
395
abs_npy_path = os .path .join (base_dir , npy_path )
362
- input_names [key ] = value .replace (matches [0 ], "r'{}'" .format (abs_npy_path ))
396
+ inputs [k ] = v .replace (matches [0 ], "r'{}'" .format (abs_npy_path ))
397
+
398
+ # parse opset_constraints
399
+ opset_constraints = []
400
+ section = settings .get ("opset_constraints" )
401
+ if section :
402
+ for k , v in section .items ():
403
+ c = OpsetConstraint (k , min_version = v .get ("min" ), max_version = v .get ("max" ),
404
+ excluded_version = v .get ("excluded" ))
405
+ opset_constraints .append (c )
363
406
364
- test = Test (url , model , input_func , input_names , v .get ("outputs" ), ** kwargs )
365
- tests [k ] = test
407
+ kwargs = {}
408
+ for kw in ["rtol" , "atol" , "disabled" , "more_inputs" , "check_only_shape" , "model_type" ,
409
+ "skip_tensorflow" , "force_input_shape" ]:
410
+ if settings .get (kw ) is not None :
411
+ kwargs [kw ] = settings [kw ]
412
+
413
+ test = Test (url , model , input_func , inputs , settings .get ("outputs" ),
414
+ opset_constraints = opset_constraints , ** kwargs )
415
+ tests [name ] = test
366
416
return tests
367
417
368
418
@@ -374,7 +424,7 @@ def main():
374
424
375
425
Test .cache_dir = args .cache
376
426
Test .target = args .target
377
- tests = tests_from_yaml (args .config )
427
+ tests = load_tests_from_yaml (args .config )
378
428
if args .list :
379
429
logger .info (sorted (tests .keys ()))
380
430
return 0
@@ -386,11 +436,22 @@ def main():
386
436
failed = 0
387
437
count = 0
388
438
for test in test_keys :
439
+ logger .info ("===================================" )
440
+
389
441
t = tests [test ]
390
- if args .tests is None and t .disabled and not args .include_disabled :
391
- continue
442
+ if args .tests is None :
443
+ if t .disabled and not args .include_disabled :
444
+ logger .info ("Skip %s: disabled" , test )
445
+ continue
446
+
447
+ condition , reason = t .check_opset_constraints (args .opset , args .extra_opset )
448
+ if not condition :
449
+ logger .info ("Skip %s: %s" , test , reason )
450
+ continue
451
+
392
452
count += 1
393
453
try :
454
+ logger .info ("Running %s" , test )
394
455
ret = t .run_test (test , backend = args .backend , onnx_file = args .onnx_file ,
395
456
opset = args .opset , extra_opset = args .extra_opset , perf = args .perf ,
396
457
fold_const = args .fold_const )
0 commit comments