45
45
from tf2onnx .tfonnx import process_tf_graph
46
46
from tf2onnx .tf_loader import tf_session , tf_reset_default_graph
47
47
from tf2onnx .graph import ExternalTensorStorage
48
+ from tfjs_runner import run_tfjs
48
49
49
50
logger = logging .getLogger ("run_pretrained" )
50
51
@@ -251,6 +252,10 @@ def download_model(self):
251
252
elif self .model_type == 'tflite' :
252
253
fname = self .local
253
254
dir_name = fname .replace (".tflite" , "" ) + "_dir"
255
+ elif self .model_type == 'tfjs' :
256
+ ftype = 'tgz'
257
+ fname = 'model.tar.gz'
258
+ dir_name = "_" .join (url .split ("/" )[5 :- 3 ]) + "_dir"
254
259
dir_name = os .path .join (cache_dir , dir_name )
255
260
os .makedirs (dir_name , exist_ok = True )
256
261
fpath = os .path .join (dir_name , fname )
@@ -303,7 +308,8 @@ def run_tensorflow(self, sess, inputs):
303
308
return result
304
309
305
310
def to_onnx (self , tf_graph , opset = None , extra_opset = None , shape_override = None , input_names = None ,
306
- const_node_values = None , initialized_tables = None , tflite_path = None , tensors_to_rename = None ):
311
+ const_node_values = None , initialized_tables = None , tflite_path = None , tensors_to_rename = None ,
312
+ tfjs_path = None ):
307
313
"""Convert graph to tensorflow."""
308
314
if extra_opset is None :
309
315
extra_opset = []
@@ -314,7 +320,7 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
314
320
input_names = input_names , output_names = self .output_names ,
315
321
const_node_values = const_node_values , initialized_tables = initialized_tables ,
316
322
tflite_path = tflite_path , dequantize = self .dequantize ,
317
- tensors_to_rename = tensors_to_rename )
323
+ tensors_to_rename = tensors_to_rename , tfjs_path = tfjs_path )
318
324
319
325
def run_onnxruntime (self , name , model_proto , inputs , outputs , external_tensor_storage = None ):
320
326
"""Run test against onnxruntime backend."""
@@ -375,6 +381,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
375
381
initialized_tables = {}
376
382
outputs = self .output_names
377
383
tflite_path = None
384
+ tfjs_path = None
378
385
to_rename = {}
379
386
if self .model_type in ["checkpoint" ]:
380
387
graph_def , input_names , outputs = tf_loader .from_checkpoint (model_path , input_names , outputs )
@@ -394,6 +401,9 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
394
401
elif self .model_type in ["tflite" ]:
395
402
tflite_path = model_path
396
403
graph_def = None
404
+ elif self .model_type in ["tfjs" ]:
405
+ tfjs_path = model_path
406
+ graph_def = None
397
407
else :
398
408
graph_def , input_names , outputs = tf_loader .from_graphdef (model_path , input_names , outputs )
399
409
@@ -434,6 +444,16 @@ def run_tflite():
434
444
logger .info ("TFLite perf {:.2f}ms/inference, n={}" .format (self .tf_runtime , n ))
435
445
logger .info ("TFLite OK" )
436
446
447
+ if tfjs_path is not None :
448
+ inputs = {}
449
+ for k in input_names :
450
+ v = self .input_names [k ]
451
+ inputs [k ] = self .make_input (v )
452
+ if not self .skip_tensorflow :
453
+ logger .info ("Running TFJS" )
454
+ tf_results = run_tfjs (tfjs_path , inputs , dir_name )
455
+ logger .info ("TFJS OK" )
456
+
437
457
if not self .run_tf_frozen :
438
458
inputs = {}
439
459
for k in input_names :
@@ -465,7 +485,6 @@ def run_tflite():
465
485
logger .info ("TF perf {:.2f}ms/inference, n={}" .format (self .tf_runtime , n ))
466
486
logger .info ("TensorFlow OK" )
467
487
468
- shape_override = {}
469
488
const_node_values = None
470
489
tf_graph = None
471
490
@@ -497,10 +516,6 @@ def run_tflite():
497
516
else :
498
517
inputs [k ] = self .make_input (v ).astype (expected_dtype )
499
518
500
- if self .force_input_shape :
501
- for k , v in inputs .items ():
502
- shape_override [k ] = list (v .shape )
503
-
504
519
# run the model with tensorflow
505
520
if self .skip_tensorflow :
506
521
logger .info ("TensorFlow SKIPPED" )
@@ -526,11 +541,15 @@ def run_tflite():
526
541
else :
527
542
try :
528
543
# convert model to onnx
544
+ if self .force_input_shape :
545
+ shape_override = {k : list (v .shape ) for k , v in inputs .items ()}
546
+ else :
547
+ shape_override = None
529
548
onnx_graph = self .to_onnx (tf_graph , opset = opset , extra_opset = extra_opset ,
530
549
shape_override = shape_override , input_names = inputs .keys (),
531
550
const_node_values = const_node_values ,
532
551
initialized_tables = initialized_tables , tflite_path = tflite_path ,
533
- tensors_to_rename = to_rename )
552
+ tensors_to_rename = to_rename , tfjs_path = tfjs_path )
534
553
onnx_graph = optimizer .optimize_graph (onnx_graph )
535
554
print ("ONNX" , onnx_graph .dump_node_statistics ())
536
555
external_tensor_storage = ExternalTensorStorage () if self .large_model else None
@@ -636,6 +655,7 @@ def get_args():
636
655
help = "extra opset with format like domain:version, e.g. com.microsoft:1" )
637
656
parser .add_argument ("--skip_tf_tests" , help = "skip non-tflite tests" , default = "False" )
638
657
parser .add_argument ("--skip_tflite_tests" , help = "skip tflite tests" , default = "False" )
658
+ parser .add_argument ("--skip_tfjs_tests" , help = "skip tfjs tests" , default = "False" )
639
659
parser .add_argument ("--verbose" , "-v" , help = "verbose output, option is additive" , action = "count" )
640
660
parser .add_argument ("--debug" , help = "debug mode" , action = "store_true" )
641
661
parser .add_argument ("--list" , help = "list tests" , action = "store_true" )
@@ -647,6 +667,7 @@ def get_args():
647
667
args .target = args .target .split ("," )
648
668
args .skip_tf_tests = args .skip_tf_tests .upper () == "TRUE"
649
669
args .skip_tflite_tests = args .skip_tflite_tests .upper () == "TRUE"
670
+ args .skip_tfjs_tests = args .skip_tfjs_tests .upper () == "TRUE"
650
671
if args .extra_opset :
651
672
tokens = args .extra_opset .split (':' )
652
673
if len (tokens ) != 2 :
@@ -739,11 +760,14 @@ def main():
739
760
logger .info ("Skip %s: disabled" , test )
740
761
continue
741
762
763
+ if args .skip_tfjs_tests and t .model_type == "tfjs" :
764
+ logger .info ("Skip %s: tfjs test" , test )
765
+ continue
742
766
if args .skip_tflite_tests and t .model_type == "tflite" :
743
767
logger .info ("Skip %s: tflite test" , test )
744
768
continue
745
- if args .skip_tf_tests and t .model_type != "tflite" :
746
- logger .info ("Skip %s: not tflite test" , test )
769
+ if args .skip_tf_tests and t .model_type not in [ "tflite" , "tfjs" ] :
770
+ logger .info ("Skip %s: tf test" , test )
747
771
continue
748
772
749
773
condition , reason = t .check_opset_constraints (args .opset , args .extra_opset )
0 commit comments