
 import argparse
 import os
+import re
 import sys
 import tarfile
 import time
 import zipfile
+from collections import namedtuple

 import PIL.Image
 import numpy as np
@@ -70,6 +72,8 @@ def get_ramp(shape):
     "get_ramp": get_ramp
 }

+OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
+

 class Test(object):
     """Main Test class."""
@@ -78,16 +82,15 @@ class Test(object):
     target = []

     def __init__(self, url, local, make_input, input_names, output_names,
-                 disabled=False, more_inputs=None, rtol=0.01, atol=1e-6,
+                 disabled=False, rtol=0.01, atol=1e-6,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
-                 skip_tensorflow=False):
+                 skip_tensorflow=False, opset_constraints=None):
         self.url = url
         self.make_input = make_input
         self.local = local
         self.input_names = input_names
         self.output_names = output_names
         self.disabled = disabled
-        self.more_inputs = more_inputs
         self.rtol = rtol
         self.atol = atol
         self.check_only_shape = check_only_shape
@@ -97,9 +100,10 @@ def __init__(self, url, local, make_input, input_names, output_names,
         self.model_type = model_type
         self.force_input_shape = force_input_shape
         self.skip_tensorflow = skip_tensorflow
+        self.opset_constraints = opset_constraints

-    def download_file(self):
-        """Download file from url."""
+    def download_model(self):
+        """Download model from url."""
         cache_dir = Test.cache_dir
         if not os.path.exists(cache_dir):
             os.makedirs(cache_dir)
@@ -163,21 +167,8 @@ def run_caffe2(self, name, model_proto, inputs):
             self.onnx_runtime = time.time() - start
         return results

-    def run_onnxmsrtnext(self, name, model_proto, inputs):
-        """Run test against msrt-next backend."""
-        import lotus
-        model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto)
-        m = lotus.InferenceSession(model_path)
-        results = m.run(self.output_names, inputs)
-        if self.perf:
-            start = time.time()
-            for _ in range(PERFITER):
-                _ = m.run(self.output_names, inputs)
-            self.onnx_runtime = time.time() - start
-        return results
-
     def run_onnxruntime(self, name, model_proto, inputs):
-        """Run test against msrt-next backend."""
+        """Run test against onnxruntime backend."""
         import onnxruntime as rt
         model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True)
         logger.info("Model saved to %s", model_path)
@@ -200,19 +191,17 @@ def create_onnx_file(name, model_proto, inputs, outdir):
     def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
                  perf=None, fold_const=None):
         """Run complete test against backend."""
-        logger.info("===================================")
-        logger.info("Running %s", name)
         self.perf = perf

         # get the model
         if self.url:
-            _, dir_name = self.download_file()
+            _, dir_name = self.download_model()
+            logger.info("Downloaded to %s", dir_name)
             model_path = os.path.join(dir_name, self.local)
         else:
             model_path = self.local
-            dir_name = os.path.dirname(self.local)
-        logger.info("Downloaded to %s", model_path)

+        logger.info("Load model from %s", model_path)
         input_names = list(self.input_names.keys())
         outputs = self.output_names
         if self.model_type in ["checkpoint"]:
@@ -222,34 +211,30 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         else:
             graph_def, input_names, outputs = loader.from_graphdef(model_path, input_names, outputs)

-        # create the input data
-        inputs = {}
-        for k, v in self.input_names.items():
-            if k not in input_names:
-                continue
-            if isinstance(v, six.text_type) and v.startswith("np."):
-                inputs[k] = eval(v)  # pylint: disable=eval-used
-            else:
-                inputs[k] = self.make_input(v)
-        if self.more_inputs:
-            for k, v in self.more_inputs.items():
-                inputs[k] = v
-
-        graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
+        # remove unused input names
+        input_names = list(set(input_names).intersection(self.input_names.keys()))
+        graph_def = tf2onnx.tfonnx.tf_optimize(input_names, self.output_names, graph_def, fold_const)
         if utils.is_debug_mode():
             utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
+
+        inputs = {}
         shape_override = {}
         g = tf.import_graph_def(graph_def, name='')
         with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
-
-            # fix inputs if needed
-            for k in inputs.keys():  # pylint: disable=consider-iterating-dictionary
+            # create the input data
+            for k in input_names:
+                v = self.input_names[k]
                 t = sess.graph.get_tensor_by_name(k)
-                dtype = tf.as_dtype(t.dtype).name
-                v = inputs[k]
-                if dtype != v.dtype:
-                    logger.warning("input dtype doesn't match tensorflow's")
-                    inputs[k] = np.array(v, dtype=dtype)
+                expected_dtype = tf.as_dtype(t.dtype).name
+                if isinstance(v, six.text_type) and v.startswith("np."):
+                    np_value = eval(v)  # pylint: disable=eval-used
+                    if expected_dtype != np_value.dtype:
+                        logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
+                                       np_value.dtype)
+                    inputs[k] = np_value.astype(expected_dtype)
+                else:
+                    inputs[k] = self.make_input(v).astype(expected_dtype)
+
             if self.force_input_shape:
                 for k, v in inputs.items():
                     shape_override[k] = list(v.shape)
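A quick illustration of the two value styles the rewritten loop accepts; the expression string and dtype below are invented. A value starting with "np." is eval'd and cast to the graph's dtype, anything else is handed to the test's make_input function:

import numpy as np

expected_dtype = "float32"                           # what the graph tensor reports
v = "np.arange(6, dtype=np.int64).reshape((2, 3))"   # yaml value written as a numpy expression
np_value = eval(v)  # pylint: disable=eval-used
print(np_value.astype(expected_dtype).dtype)         # float32, matching the graph input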
@@ -279,8 +264,6 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         onnx_results = None
         if backend == "caffe2":
             onnx_results = self.run_caffe2(name, model_proto, inputs)
-        elif backend == "onnxmsrtnext":
-            onnx_results = self.run_onnxmsrtnext(name, model_proto, inputs)
         elif backend == "onnxruntime":
             onnx_results = self.run_onnxruntime(name, model_proto, inputs)
         else:
@@ -307,6 +290,41 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops

         return False

+    def check_opset_constraints(self, opset, extra_opset=None):
+        """Return (condition, reason) tuple, condition is True if constraints are met."""
+        if not self.opset_constraints:
+            return True, None
+
+        opsets = {"onnx": opset}
+        if extra_opset:
+            for e in extra_opset:
+                opsets[e.domain] = e.version
+
+        for constraint in self.opset_constraints:
+            domain = constraint.domain
+            opset_version = opsets.get(domain)
+            if not opset_version:
+                return False, "conversion requires opset {}".format(domain)
+
+            if constraint.min_version and opset_version < constraint.min_version:
+                reason = "conversion requires opset {} >= {}".format(domain, constraint.min_version)
+                return False, reason
+
+            if constraint.max_version and opset_version > constraint.max_version:
+                reason = "conversion requires opset {} <= {}".format(domain, constraint.max_version)
+                return False, reason
+
+            if constraint.excluded_version:
+                if utils.is_list_or_tuple(constraint.excluded_version):
+                    skip = opset_version in constraint.excluded_version
+                else:
+                    skip = opset_version == constraint.excluded_version
+                if skip:
+                    reason = "conversion requires opset {} != {}".format(domain, constraint.excluded_version)
+                    return False, reason
+
+        return True, None
+


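To make the new gating concrete, here is a small self-contained mirror of the method above (min/max only); the Opset stand-in and all version numbers are invented:

from collections import namedtuple

OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
Opset = namedtuple("Opset", "domain, version")  # stand-in: real extra_opset entries expose the same fields

def check(constraints, opset, extra_opset=None):
    """Simplified mirror of Test.check_opset_constraints."""
    opsets = {"onnx": opset}
    for e in extra_opset or []:
        opsets[e.domain] = e.version
    for c in constraints:
        version = opsets.get(c.domain)
        if not version:
            return False, "conversion requires opset {}".format(c.domain)
        if c.min_version and version < c.min_version:
            return False, "conversion requires opset {} >= {}".format(c.domain, c.min_version)
        if c.max_version and version > c.max_version:
            return False, "conversion requires opset {} <= {}".format(c.domain, c.max_version)
    return True, None

constraints = [OpsetConstraint("onnx", 8, None, None),
               OpsetConstraint("com.microsoft", 1, None, None)]
print(check(constraints, 7))                               # (False, 'conversion requires opset onnx >= 8')
print(check(constraints, 9, [Opset("com.microsoft", 1)]))  # (True, None)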
 def get_args():
     """Parse commandline."""
@@ -316,7 +334,7 @@ def get_args():
     parser.add_argument("--tests", help="tests to run")
     parser.add_argument("--target", default="", help="target platform")
     parser.add_argument("--backend", default="onnxruntime",
-                        choices=["caffe2", "onnxmsrtnext", "onnxruntime"], help="backend to use")
+                        choices=["caffe2", "onnxruntime"], help="backend to use")
     parser.add_argument("--opset", type=int, default=None, help="opset to use")
     parser.add_argument("--extra_opset", default=None,
                         help="extra opset with format like domain:version, e.g. com.microsoft:1")
@@ -339,21 +357,57 @@ def get_args():
     return args


-def tests_from_yaml(fname):
+def load_tests_from_yaml(path):
     """Create test class from yaml file."""
+    path = os.path.abspath(path)
+    base_dir = os.path.dirname(path)
+
     tests = {}
-    config = yaml.load(open(fname, 'r').read())
-    for k, v in config.items():
-        input_func = v.get("input_get")
+    config = yaml.safe_load(open(path, 'r').read())
+    for name, settings in config.items():
+        if name in tests:
+            raise ValueError("Found duplicated test: {}".format(name))
+
+        # parse model and url, non-absolute local path is relative to yaml directory
+        model = settings.get("model")
+        url = settings.get("url")
+        if not url and not os.path.isabs(model):
+            model = os.path.join(base_dir, model)
+
+        # parse input_get
+        input_func = settings.get("input_get")
         input_func = _INPUT_FUNC_MAPPING[input_func]
+
+        # parse inputs, non-absolute npy file path for np.load is relative to yaml directory
+        inputs = settings.get("inputs")
+        for k, v in list(inputs.items()):
+            if isinstance(v, str):
+                # assume at most 1 match
+                matches = re.findall(r"np\.load\((r?['\"].*?['\"])", v)
+                if matches:
+                    npy_path = matches[0].lstrip('r').strip("'").strip('"')
+                    if not os.path.isabs(npy_path):
+                        abs_npy_path = os.path.join(base_dir, npy_path)
+                        inputs[k] = v.replace(matches[0], "r'{}'".format(abs_npy_path))
+
+        # parse opset_constraints
+        opset_constraints = []
+        section = settings.get("opset_constraints")
+        if section:
+            for k, v in section.items():
+                c = OpsetConstraint(k, min_version=v.get("min"), max_version=v.get("max"),
+                                    excluded_version=v.get("excluded"))
+                opset_constraints.append(c)
+
         kwargs = {}
-        for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type",
+        for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
                    "skip_tensorflow", "force_input_shape"]:
-            if v.get(kw) is not None:
-                kwargs[kw] = v[kw]
+            if settings.get(kw) is not None:
+                kwargs[kw] = settings[kw]

-        test = Test(v.get("url"), v.get("model"), input_func, v.get("inputs"), v.get("outputs"), **kwargs)
-        tests[k] = test
+        test = Test(url, model, input_func, inputs, settings.get("outputs"),
+                    opset_constraints=opset_constraints, **kwargs)
+        tests[name] = test
     return tests

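For reference, an invented yaml entry exercising the fields this loader reads; only the field names (model, url, input_get, inputs, outputs, opset_constraints with min/max/excluded) come from the code above:

import yaml

example = yaml.safe_load("""
example_test:
  model: models/example.pb       # non-absolute, so resolved against the yaml directory
  input_get: get_ramp
  inputs:
    "input:0": [1, 224, 224, 3]  # shape handed to the input_get function
  outputs:
    - "output:0"
  opset_constraints:
    onnx:
      min: 8
      excluded: 9
""")
print(example["example_test"]["opset_constraints"])  # {'onnx': {'min': 8, 'excluded': 9}}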
@@ -365,7 +419,7 @@ def main():

     Test.cache_dir = args.cache
     Test.target = args.target
-    tests = tests_from_yaml(args.config)
+    tests = load_tests_from_yaml(args.config)
     if args.list:
         logger.info(sorted(tests.keys()))
         return 0
@@ -377,11 +431,22 @@ def main():
     failed = 0
     count = 0
     for test in test_keys:
+        logger.info("===================================")
+
         t = tests[test]
-        if args.tests is None and t.disabled and not args.include_disabled:
-            continue
+        if args.tests is None:
+            if t.disabled and not args.include_disabled:
+                logger.info("Skip %s: disabled", test)
+                continue
+
+            condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)
+            if not condition:
+                logger.info("Skip %s: %s", test, reason)
+                continue
+
         count += 1
         try:
+            logger.info("Running %s", test)
             ret = t.run_test(test, backend=args.backend, onnx_file=args.onnx_file,
                              opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
                              fold_const=args.fold_const)