56
56
PERFITER = 1000
57
57
58
58
59
- def get_beach (shape ):
60
- """Get beach image as input."""
59
+ def get_img (shape , path , dtype , should_scale = True ):
60
+ """Get image as input."""
61
61
resize_to = shape [1 :3 ]
62
- path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "beach.jpg" )
62
+ path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), path )
63
63
img = PIL .Image .open (path )
64
64
img = img .resize (resize_to , PIL .Image .ANTIALIAS )
65
- img_np = np .array (img ).astype (np . float32 )
65
+ img_np = np .array (img ).astype (dtype )
66
66
img_np = np .stack ([img_np ] * shape [0 ], axis = 0 ).reshape (shape )
67
- return img_np / 255
67
+ if should_scale :
68
+ img_np = img_np / 255
69
+ return img_np
70
+
71
+
72
+ def get_beach (shape ):
73
+ """Get beach image as input."""
74
+ return get_img (shape , "beach.jpg" , np .float32 , should_scale = True )
75
+
76
+
77
+ def get_car (shape ):
78
+ """Get car image as input."""
79
+ return get_img (shape , "car.JPEG" , np .float32 , should_scale = True )
68
80
69
81
70
82
def get_random (shape ):
@@ -133,6 +145,7 @@ def get_sentence():
133
145
134
146
_INPUT_FUNC_MAPPING = {
135
147
"get_beach" : get_beach ,
148
+ "get_car" : get_car ,
136
149
"get_random" : get_random ,
137
150
"get_random256" : get_random256 ,
138
151
"get_ramp" : get_ramp ,
@@ -219,6 +232,9 @@ def download_model(self):
219
232
elif url .endswith ('.zip' ):
220
233
ftype = 'zip'
221
234
dir_name = fname .replace (".zip" , "" )
235
+ elif url .endswith ('.tflite' ):
236
+ ftype = 'tflite'
237
+ dir_name = fname .replace (".tflite" , "" )
222
238
dir_name = os .path .join (cache_dir , dir_name )
223
239
os .makedirs (dir_name , exist_ok = True )
224
240
fpath = os .path .join (dir_name , fname )
@@ -266,7 +282,7 @@ def run_tensorflow(self, sess, inputs):
266
282
return result
267
283
268
284
def to_onnx (self , tf_graph , opset = None , extra_opset = None , shape_override = None , input_names = None ,
269
- const_node_values = None , initialized_tables = None ):
285
+ const_node_values = None , initialized_tables = None , tflite_path = None ):
270
286
"""Convert graph to tensorflow."""
271
287
if extra_opset is None :
272
288
extra_opset = []
@@ -275,8 +291,8 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
275
291
return process_tf_graph (tf_graph , continue_on_error = False , opset = opset ,
276
292
extra_opset = extra_opset , target = Test .target , shape_override = shape_override ,
277
293
input_names = input_names , output_names = self .output_names ,
278
- const_node_values = const_node_values ,
279
- initialized_tables = initialized_tables )
294
+ const_node_values = const_node_values , initialized_tables = initialized_tables ,
295
+ tflite_path = tflite_path )
280
296
281
297
def run_caffe2 (self , name , model_proto , inputs ):
282
298
"""Run test again caffe2 backend."""
@@ -340,6 +356,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
340
356
input_names = list (self .input_names .keys ())
341
357
initialized_tables = {}
342
358
outputs = self .output_names
359
+ tflite_path = None
343
360
if self .model_type in ["checkpoint" ]:
344
361
graph_def , input_names , outputs = tf_loader .from_checkpoint (model_path , input_names , outputs )
345
362
elif self .model_type in ["saved_model" ]:
@@ -355,12 +372,43 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
355
372
graph_def , input_names , outputs , initialized_tables = loaded
356
373
elif self .model_type in ["keras" ]:
357
374
graph_def , input_names , outputs = tf_loader .from_keras (model_path , input_names , outputs )
375
+ elif self .model_type in ["tflite" ]:
376
+ tflite_path = model_path
377
+ graph_def = None
358
378
else :
359
379
graph_def , input_names , outputs = tf_loader .from_graphdef (model_path , input_names , outputs )
360
380
361
381
if utils .is_debug_mode ():
362
382
utils .save_protobuf (os .path .join (TEMP_DIR , name + "_after_tf_optimize.pb" ), graph_def )
363
383
384
+ if tflite_path is not None :
385
+ inputs = {}
386
+ for k in input_names :
387
+ v = self .input_names [k ]
388
+ inputs [k ] = self .make_input (v )
389
+
390
+ interpreter = tf .lite .Interpreter (tflite_path )
391
+ input_details = interpreter .get_input_details ()
392
+ output_details = interpreter .get_output_details ()
393
+ input_name_to_index = {n ['name' ].split (':' )[0 ]: n ['index' ] for n in input_details }
394
+ for k , v in inputs .items ():
395
+ interpreter .resize_tensor_input (input_name_to_index [k ], v .shape )
396
+ interpreter .allocate_tensors ()
397
+ def run_tflite ():
398
+ for k , v in inputs .items ():
399
+ interpreter .set_tensor (input_name_to_index [k ], v )
400
+ interpreter .invoke ()
401
+ result = [interpreter .get_tensor (output ['index' ]) for output in output_details ]
402
+ return result
403
+ tf_results = run_tflite ()
404
+ if self .perf :
405
+ logger .info ("Running TFLite perf" )
406
+ start = time .time ()
407
+ for _ in range (PERFITER ):
408
+ _ = run_tflite ()
409
+ self .tf_runtime = time .time () - start
410
+ logger .info ("TFLite OK" )
411
+
364
412
if not self .run_tf_frozen :
365
413
inputs = {}
366
414
for k in input_names :
@@ -384,45 +432,50 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
384
432
self .tf_runtime = time .time () - start
385
433
logger .info ("TensorFlow OK" )
386
434
387
- inputs = {}
388
435
shape_override = {}
389
- tf_reset_default_graph ()
390
-
391
- from tf2onnx .tf_utils import compress_graph_def
392
436
const_node_values = None
393
- with tf .Graph ().as_default () as tf_graph :
394
- if self .large_model :
395
- const_node_values = compress_graph_def (graph_def )
396
- tf .import_graph_def (graph_def , name = '' )
437
+ tf_graph = None
397
438
398
- with tf_session (graph = tf_graph ) as sess :
399
- # create the input data
400
- for k in input_names :
401
- v = self .input_names [k ]
402
- t = sess .graph .get_tensor_by_name (k )
403
- expected_dtype = tf .as_dtype (t .dtype ).name
404
- if isinstance (v , six .text_type ) and v .startswith ("np." ):
405
- np_value = eval (v ) # pylint: disable=eval-used
406
- if expected_dtype != np_value .dtype :
407
- logger .warning ("dtype mismatch for input %s: expected=%s, actual=%s" , k , expected_dtype ,
408
- np_value .dtype )
409
- inputs [k ] = np_value .astype (expected_dtype )
410
- else :
411
- if expected_dtype == "string" :
412
- inputs [k ] = self .make_input (v ).astype (np .str ).astype (np .object )
439
+ if graph_def is not None :
440
+ inputs = {}
441
+ tf_reset_default_graph ()
442
+
443
+ with tf .Graph ().as_default () as tf_graph :
444
+ from tf2onnx .tf_utils import compress_graph_def
445
+ if self .large_model :
446
+ const_node_values = compress_graph_def (graph_def )
447
+ tf .import_graph_def (graph_def , name = '' )
448
+
449
+ with tf_session (graph = tf_graph ) as sess :
450
+ # create the input data
451
+ for k in input_names :
452
+ v = self .input_names [k ]
453
+ t = sess .graph .get_tensor_by_name (k )
454
+ expected_dtype = tf .as_dtype (t .dtype ).name
455
+ if isinstance (v , six .text_type ) and v .startswith ("np." ):
456
+ np_value = eval (v ) # pylint: disable=eval-used
457
+ if expected_dtype != np_value .dtype :
458
+ logger .warning ("dtype mismatch for input %s: expected=%s, actual=%s" , k , expected_dtype ,
459
+ np_value .dtype )
460
+ inputs [k ] = np_value .astype (expected_dtype )
413
461
else :
414
- inputs [k ] = self .make_input (v ).astype (expected_dtype )
462
+ if expected_dtype == "string" :
463
+ inputs [k ] = self .make_input (v ).astype (np .str ).astype (np .object )
464
+ else :
465
+ inputs [k ] = self .make_input (v ).astype (expected_dtype )
415
466
416
- if self .force_input_shape :
417
- for k , v in inputs .items ():
418
- shape_override [k ] = list (v .shape )
467
+ if self .force_input_shape :
468
+ for k , v in inputs .items ():
469
+ shape_override [k ] = list (v .shape )
470
+
471
+ # run the model with tensorflow
472
+ if self .skip_tensorflow :
473
+ logger .info ("TensorFlow SKIPPED" )
474
+ elif self .run_tf_frozen :
475
+ tf_results = self .run_tensorflow (sess , inputs )
476
+ logger .info ("TensorFlow OK" )
477
+ tf_graph = sess .graph
419
478
420
- # run the model with tensorflow
421
- if self .skip_tensorflow :
422
- logger .info ("TensorFlow SKIPPED" )
423
- elif self .run_tf_frozen :
424
- tf_results = self .run_tensorflow (sess , inputs )
425
- logger .info ("TensorFlow OK" )
426
479
427
480
model_proto = None
428
481
if self .skip_conversion :
@@ -436,10 +489,10 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
436
489
else :
437
490
try :
438
491
# convert model to onnx
439
- onnx_graph = self .to_onnx (sess . graph , opset = opset , extra_opset = extra_opset ,
492
+ onnx_graph = self .to_onnx (tf_graph , opset = opset , extra_opset = extra_opset ,
440
493
shape_override = shape_override , input_names = inputs .keys (),
441
494
const_node_values = const_node_values ,
442
- initialized_tables = initialized_tables )
495
+ initialized_tables = initialized_tables , tflite_path = tflite_path )
443
496
onnx_graph = optimizer .optimize_graph (onnx_graph )
444
497
print ("ONNX" , onnx_graph .dump_node_statistics ())
445
498
external_tensor_storage = ExternalTensorStorage () if self .large_model else None
@@ -538,6 +591,8 @@ def get_args():
538
591
parser .add_argument ("--opset" , type = int , default = None , help = "opset to use" )
539
592
parser .add_argument ("--extra_opset" , default = None ,
540
593
help = "extra opset with format like domain:version, e.g. com.microsoft:1" )
594
+ parser .add_argument ("--skip_tf_tests" , help = "skip non-tflite tests" , default = "False" )
595
+ parser .add_argument ("--skip_tflite_tests" , help = "skip tflite tests" , default = "False" )
541
596
parser .add_argument ("--verbose" , "-v" , help = "verbose output, option is additive" , action = "count" )
542
597
parser .add_argument ("--debug" , help = "debug mode" , action = "store_true" )
543
598
parser .add_argument ("--list" , help = "list tests" , action = "store_true" )
@@ -550,6 +605,8 @@ def get_args():
550
605
args = parser .parse_args ()
551
606
552
607
args .target = args .target .split ("," )
608
+ args .skip_tf_tests = args .skip_tf_tests .upper () == "TRUE"
609
+ args .skip_tflite_tests = args .skip_tflite_tests .upper () == "TRUE"
553
610
if args .extra_opset :
554
611
tokens = args .extra_opset .split (':' )
555
612
if len (tokens ) != 2 :
@@ -644,6 +701,13 @@ def main():
644
701
logger .info ("Skip %s: disabled" , test )
645
702
continue
646
703
704
+ if args .skip_tflite_tests and t .model_type == "tflite" :
705
+ logger .info ("Skip %s: tflite test" , test )
706
+ continue
707
+ if args .skip_tf_tests and t .model_type != "tflite" :
708
+ logger .info ("Skip %s: not tflite test" , test )
709
+ continue
710
+
647
711
condition , reason = t .check_opset_constraints (args .opset , args .extra_opset )
648
712
if not condition :
649
713
logger .info ("Skip %s: %s" , test , reason )
0 commit comments