33
33
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
34
34
# pylint: disable=unused-variable
35
35
36
- def fold_constants_using_tf (g , outputs_to_values , outputs_to_dtypes ):
36
+ def fold_constants_using_tf (g , outputs_to_values ):
37
37
ops = list (g .get_nodes ())
38
38
# pylint: disable=too-many-nested-blocks
39
39
keep_looking = True
@@ -409,14 +409,13 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
409
409
del verbose
410
410
411
411
opset = utils .find_opset (opset )
412
- if not is_subgraph :
413
- logger .info ("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s" ,
414
- get_tf_version (), utils .get_onnx_version (), tf2onnx .__version__ , tf2onnx .version .git_version [:6 ])
415
- logger .info ("Using opset <onnx, %s>" , opset )
416
- if opset > schemas .get_max_supported_opset_version ():
417
- logger .warning ("Currently installed onnx package %s is too low to support opset %s, "
418
- "please upgrade onnx package to avoid potential conversion issue." ,
419
- utils .get_onnx_version (), opset )
412
+ logger .info ("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s" ,
413
+ get_tf_version (), utils .get_onnx_version (), tf2onnx .__version__ , tf2onnx .version .git_version [:6 ])
414
+ logger .info ("Using opset <onnx, %s>" , opset )
415
+ if opset > schemas .get_max_supported_opset_version ():
416
+ logger .warning ("Currently installed onnx package %s is too low to support opset %s, "
417
+ "please upgrade onnx package to avoid potential conversion issue." ,
418
+ utils .get_onnx_version (), opset )
420
419
421
420
if shape_override is None :
422
421
shape_override = {}
@@ -440,34 +439,17 @@ def check_io(input_names, output_names, output_shapes):
440
439
non_exists )
441
440
raise ValueError ("Inputs/Outputs Not Found" )
442
441
443
- def rename_tensors_in_dict (d ):
444
- if tensors_to_rename is None :
445
- return d
446
- return {tensors_to_rename .get (k , k ): v for k , v in d .items ()}
447
-
448
- def rename_tensors_in_list (tensors ):
449
- if tensors_to_rename is None or tensors is None :
450
- return tensors
451
- return [tensors_to_rename .get (t , t ) for t in tensors ]
452
-
453
- def rename_tensors_in_nodes (onnx_nodes ):
454
- if tensors_to_rename is None :
455
- return
456
- for n in onnx_nodes :
457
- n .input [:] = rename_tensors_in_list (n .input )
458
- n .output [:] = rename_tensors_in_list (n .output )
459
-
460
442
if tflite_path is not None :
461
443
tflite_graphs , opcodes , model , tensor_shapes = read_tflite_model (tflite_path )
462
444
main_g = None
463
- inputs_as_nchw = rename_tensors_in_list ( inputs_as_nchw )
445
+ subgraphs = []
464
446
for i , tfl_graph in enumerate (tflite_graphs ):
465
447
is_main_g = i == len (tflite_graphs ) - 1
466
448
prefix = '' if is_main_g else tfl_graph .Name ().decode () + '_'
467
449
tensor_shapes_from_interpreter = None
468
450
if is_main_g :
469
451
tensor_shapes_from_interpreter = tensor_shapes
470
- onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes , f_inputs , f_outputs , graph_name = \
452
+ onnx_nodes , _ , _ , output_shapes , dtypes , f_inputs , f_outputs , graph_name = \
471
453
parse_tflite_graph (tfl_graph , opcodes , model , prefix , tensor_shapes_from_interpreter )
472
454
g_inputs = f_inputs
473
455
g_outputs = f_outputs
@@ -478,63 +460,73 @@ def rename_tensors_in_nodes(onnx_nodes):
478
460
g_inputs = input_names
479
461
if output_names is not None :
480
462
g_outputs = output_names
481
- rename_tensors_in_nodes (onnx_nodes )
482
- g_inputs = rename_tensors_in_list (g_inputs )
483
- g_outputs = rename_tensors_in_list (g_outputs )
484
- output_shapes = rename_tensors_in_dict (output_shapes )
485
- dtypes = rename_tensors_in_dict (dtypes )
486
- g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , g_inputs , g_outputs , is_subgraph )
487
- fg = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
488
- g_outputs , {}, {}, {}, op_cnt , attr_cnt , is_tflite = True , dequantize = dequantize )
489
- fg .graph_name = graph_name
463
+ g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , g_inputs , g_outputs ,
464
+ not is_main_g , graph_name )
490
465
if is_main_g :
491
- main_g = fg
466
+ main_g = g
492
467
else :
493
- set_function (graph_name , fg )
468
+ subgraphs .append (g )
469
+
470
+ g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
471
+ target , {}, tensors_to_rename , is_tflite = True , dequantize = dequantize )
472
+ return g
473
+
474
+ # make tf2onnx internal subgraphs from the tensorflow subgraphs
475
+ ordered_func = resolve_functions (tf_graph )
476
+ subgraphs = []
477
+ for func in ordered_func :
478
+ f_inputs_names = [t .name for t in func .inputs ]
479
+ f_output_names = [t .name for t in func .outputs ]
480
+
481
+ outputs_to_values , _ = compute_const_folding_using_tf (func , const_node_values , output_names )
482
+
483
+ onnx_nodes , _ , _ , output_shapes , dtypes , _ = \
484
+ tensorflow_to_onnx (func , shape_override , const_node_values , ignore_default , use_default )
494
485
495
- return main_g
486
+ fg = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , f_inputs_names , f_output_names ,
487
+ is_subgraph = True , graph_name = func .name )
488
+ fold_constants_using_tf (fg , outputs_to_values )
489
+ subgraphs .append (fg )
496
490
497
491
is_func = is_function (tf_graph )
498
492
if not is_func :
499
493
tf_graph = infer_shape (tf_graph , shape_override )
500
494
501
- outputs_to_values , outputs_to_dtypes = compute_const_folding_using_tf (tf_graph , const_node_values , output_names )
495
+ outputs_to_values , _ = compute_const_folding_using_tf (tf_graph , const_node_values , output_names )
502
496
503
- onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes , _ = \
497
+ onnx_nodes , _ , _ , output_shapes , dtypes , _ = \
504
498
tensorflow_to_onnx (tf_graph , shape_override , const_node_values , ignore_default , use_default )
505
- if not is_subgraph :
506
- # make tf2onnx internal subgraphs from the tensorflow subgraphs
507
- ordered_func = resolve_functions (tf_graph )
508
- for func in ordered_func :
509
- f_inputs_names = [t .name for t in func .inputs ]
510
- f_output_names = [t .name for t in func .outputs ]
511
- fg = process_tf_graph (func , continue_on_error , False , target , opset ,
512
- custom_op_handlers , custom_rewriter ,
513
- extra_opset , shape_override , inputs_as_nchw ,
514
- f_inputs_names , f_output_names , is_subgraph = True ,
515
- const_node_values = const_node_values , tensors_to_rename = tensors_to_rename ,
516
- initialized_tables = initialized_tables )
517
- fg .graph_name = func .name
518
- set_function (func .name , fg )
519
499
520
500
check_io (input_names , output_names , output_shapes )
501
+ main_g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , input_names , output_names ,
502
+ is_subgraph )
503
+ fold_constants_using_tf (main_g , outputs_to_values )
504
+ g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
505
+ target , initialized_tables , tensors_to_rename )
506
+ return g
507
+
508
+
509
+ def process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
510
+ initialized_tables , tensors_to_rename , is_tflite = False , dequantize = False ):
521
511
522
- if not is_subgraph :
523
- rename_tensors_in_nodes (onnx_nodes )
524
- input_names = rename_tensors_in_list (input_names )
525
- output_names = rename_tensors_in_list (output_names )
526
- output_shapes = rename_tensors_in_dict (output_shapes )
527
- dtypes = rename_tensors_in_dict (dtypes )
528
- inputs_as_nchw = rename_tensors_in_list (inputs_as_nchw )
529
- g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , input_names , output_names , is_subgraph )
530
- g = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
531
- output_names , initialized_tables , outputs_to_values , outputs_to_dtypes , op_cnt , attr_cnt )
512
+ if tensors_to_rename is not None :
513
+ main_g .rename_tensors (tensors_to_rename )
514
+ inputs_as_nchw = [tensors_to_rename .get (t , t ) for t in inputs_as_nchw ]
515
+
516
+ for g in subgraphs :
517
+ fg = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
518
+ initialized_tables , is_tflite , dequantize )
519
+ set_function (fg .graph_name , fg )
520
+ g = process_parsed_graph (main_g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
521
+ initialized_tables , is_tflite ,
522
+ dequantize )
532
523
return g
533
524
534
525
535
526
def process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
536
- output_names , initialized_tables , outputs_to_values , outputs_to_dtypes , op_cnt , attr_cnt ,
537
- is_tflite = False , dequantize = False ):
527
+ initialized_tables , is_tflite = False , dequantize = False ):
528
+
529
+ op_cnt , attr_cnt = g .dump_node_statistics (include_attrs = True , include_subgraphs = False )
538
530
539
531
if is_tflite :
540
532
tfl_rewriters = []
@@ -587,8 +579,6 @@ def compat_handler(ctx, node, **kwargs):
587
579
if inputs_as_nchw :
588
580
transpose_inputs (g , inputs_as_nchw )
589
581
590
- fold_constants_using_tf (g , outputs_to_values , outputs_to_dtypes )
591
-
592
582
# pre-processing graph rewrites
593
583
# bi-directional re-writer should be placed after single directional re-writer
594
584
rewriters = [
@@ -626,7 +616,7 @@ def compat_handler(ctx, node, **kwargs):
626
616
run_rewriters (g , rewriters , continue_on_error )
627
617
628
618
# some nodes may already copied into inner Graph, so remove them from main Graph.
629
- g .delete_unused_nodes (output_names )
619
+ g .delete_unused_nodes (g . outputs )
630
620
topological_sort (g , continue_on_error )
631
621
632
622
mapped_op , unmapped_op , exceptions = \
0 commit comments