23
23
from tf2onnx .shape_inference import infer_shape
24
24
from tf2onnx .tf_loader import is_function , resolve_functions , set_function
25
25
from tf2onnx .tf_utils import tensorflow_to_onnx , get_tf_version , compute_const_folding_using_tf
26
- from tf2onnx .tflite_utils import read_tflite_model , parse_tflite_graph
26
+ from tf2onnx .tflite_utils import graphs_from_tflite
27
27
28
28
from . import constants , logging , schemas , utils , handler
29
29
@@ -417,61 +417,29 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
417
417
"please upgrade onnx package to avoid potential conversion issue." ,
418
418
utils .get_onnx_version (), opset )
419
419
420
- if shape_override is None :
421
- shape_override = {}
422
420
if inputs_as_nchw is None :
423
421
inputs_as_nchw = []
424
- if target is None :
425
- target = constants .DEFAULT_TARGET
426
-
427
- def check_io (input_names , output_names , output_shapes ):
428
- io_to_check = []
429
- if input_names :
430
- io_to_check .extend (input_names )
431
- if output_names :
432
- io_to_check .extend (output_names )
433
- if io_to_check :
434
- # check output existence in case user passed in wrong output ids
435
- non_exists = set (io_to_check ) - set (output_shapes .keys ())
436
- if non_exists :
437
- logger .error ("\n Failed to convert: inputs/outputs specified do not exist, make sure your passed"
438
- "in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n " ,
439
- non_exists )
440
- raise ValueError ("Inputs/Outputs Not Found" )
441
422
423
+ is_tflite = False
442
424
if tflite_path is not None :
443
- tflite_graphs , opcodes , model , tensor_shapes = read_tflite_model (tflite_path )
444
- main_g = None
445
- subgraphs = []
446
- for i , tfl_graph in enumerate (tflite_graphs ):
447
- is_main_g = i == len (tflite_graphs ) - 1
448
- prefix = '' if is_main_g else tfl_graph .Name ().decode () + '_'
449
- tensor_shapes_from_interpreter = None
450
- if is_main_g :
451
- tensor_shapes_from_interpreter = tensor_shapes
452
- onnx_nodes , _ , _ , output_shapes , dtypes , f_inputs , f_outputs , graph_name = \
453
- parse_tflite_graph (tfl_graph , opcodes , model , prefix , tensor_shapes_from_interpreter )
454
- g_inputs = f_inputs
455
- g_outputs = f_outputs
456
- if is_main_g :
457
- # Override IO in main graph
458
- check_io (input_names , output_names , output_shapes )
459
- if input_names is not None :
460
- g_inputs = input_names
461
- if output_names is not None :
462
- g_outputs = output_names
463
- g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , g_inputs , g_outputs ,
464
- not is_main_g , graph_name )
465
- if is_main_g :
466
- main_g = g
467
- else :
468
- subgraphs .append (g )
425
+ main_g , subgraphs = graphs_from_tflite (tflite_path , input_names , output_names )
426
+ is_tflite = True
427
+ else :
428
+ main_g , subgraphs = graphs_from_tf (tf_graph , input_names , output_names , shape_override , const_node_values ,
429
+ ignore_default , use_default )
469
430
470
- g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
471
- target , {}, tensors_to_rename , is_tflite = True , dequantize = dequantize )
472
- return g
431
+ for g in [main_g ] + subgraphs :
432
+ g .set_config (target , opset , extra_opset )
433
+ g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
434
+ initialized_tables , tensors_to_rename , is_tflite , dequantize )
435
+ return g
473
436
474
- # make tf2onnx internal subgraphs from the tensorflow subgraphs
437
+
438
+ def graphs_from_tf (tf_graph , input_names , output_names , shape_override = None , const_node_values = None ,
439
+ ignore_default = None , use_default = None ):
440
+ """make tf2onnx internal subgraphs from the tensorflow subgraphs"""
441
+ if shape_override is None :
442
+ shape_override = {}
475
443
ordered_func = resolve_functions (tf_graph )
476
444
subgraphs = []
477
445
for func in ordered_func :
@@ -483,7 +451,7 @@ def check_io(input_names, output_names, output_shapes):
483
451
onnx_nodes , _ , _ , output_shapes , dtypes , _ = \
484
452
tensorflow_to_onnx (func , shape_override , const_node_values , ignore_default , use_default )
485
453
486
- fg = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , f_inputs_names , f_output_names ,
454
+ fg = Graph (onnx_nodes , output_shapes , dtypes , input_names = f_inputs_names , output_names = f_output_names ,
487
455
is_subgraph = True , graph_name = func .name )
488
456
fold_constants_using_tf (fg , outputs_to_values )
489
457
subgraphs .append (fg )
@@ -497,33 +465,30 @@ def check_io(input_names, output_names, output_shapes):
497
465
onnx_nodes , _ , _ , output_shapes , dtypes , _ = \
498
466
tensorflow_to_onnx (tf_graph , shape_override , const_node_values , ignore_default , use_default )
499
467
500
- check_io (input_names , output_names , output_shapes )
501
- main_g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , input_names , output_names ,
502
- is_subgraph )
468
+ utils .check_io (input_names , output_names , output_shapes .keys ())
469
+ main_g = Graph (onnx_nodes , output_shapes , dtypes , input_names = input_names , output_names = output_names )
503
470
fold_constants_using_tf (main_g , outputs_to_values )
504
- g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
505
- target , initialized_tables , tensors_to_rename )
506
- return g
471
+ return main_g , subgraphs
507
472
508
473
509
- def process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
474
+ def process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
510
475
initialized_tables , tensors_to_rename , is_tflite = False , dequantize = False ):
511
476
512
477
if tensors_to_rename is not None :
513
478
main_g .rename_tensors (tensors_to_rename )
514
479
inputs_as_nchw = [tensors_to_rename .get (t , t ) for t in inputs_as_nchw ]
515
480
516
481
for g in subgraphs :
517
- fg = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
482
+ fg = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
518
483
initialized_tables , is_tflite , dequantize )
519
484
set_function (fg .graph_name , fg )
520
- g = process_parsed_graph (main_g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
485
+ g = process_parsed_graph (main_g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
521
486
initialized_tables , is_tflite ,
522
487
dequantize )
523
488
return g
524
489
525
490
526
- def process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
491
+ def process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
527
492
initialized_tables , is_tflite = False , dequantize = False ):
528
493
529
494
op_cnt , attr_cnt = g .dump_node_statistics (include_attrs = True , include_subgraphs = False )
@@ -628,11 +593,11 @@ def compat_handler(ctx, node, **kwargs):
628
593
629
594
# post-processing rewriters
630
595
late_rewriters = []
631
- if constants .TARGET_RS5 in target :
596
+ if g . is_target ( constants .TARGET_RS5 ) :
632
597
late_rewriters .append (rewrite_incomplete_type_support_rs5 )
633
- if constants .TARGET_RS6 in target :
598
+ if g . is_target ( constants .TARGET_RS6 ) :
634
599
late_rewriters .append (rewrite_incomplete_type_support_rs6 )
635
- if constants .TARGET_CHANNELS_LAST in target :
600
+ if g . is_target ( constants .TARGET_CHANNELS_LAST ) :
636
601
late_rewriters .append (rewrite_channels_last )
637
602
if late_rewriters :
638
603
run_rewriters (g , late_rewriters , continue_on_error )
0 commit comments