@@ -2419,19 +2419,6 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
2419
2419
return mapped_op , unmapped_op
2420
2420
2421
2421
2422
- def run_pre_rewriters (g , rewriters ):
2423
- ops = g .get_nodes ()
2424
- for rewrite in rewriters :
2425
- ops = rewrite (g , ops )
2426
- g .set_nodes (ops )
2427
- for node in ops :
2428
- body_graphs = node .get_body_graphs ()
2429
- if body_graphs :
2430
- for attr , b_g in body_graphs .items ():
2431
- log .debug ("start rewriting subgraph of %s's attribute %s" , node .name , attr )
2432
- run_pre_rewriters (b_g , rewriters )
2433
-
2434
-
2435
2422
def transpose_inputs (ctx , inputs_as_nchw ):
2436
2423
"""Insert a transpose from NHWC to NCHW on model input on users request."""
2437
2424
ops = []
@@ -2488,17 +2475,18 @@ def topological_sort(g, continue_on_error):
2488
2475
pass
2489
2476
2490
2477
2491
- def run_late_rewriters (g , funcs , continue_on_error ):
2492
- if g .contained_graphs :
2493
- for dict_val in g .contained_graphs .values ():
2494
- for attr_name , b_g in dict_val .items ():
2495
- run_late_rewriters (b_g , funcs , attr_name )
2496
-
2497
- topological_sort (g , continue_on_error )
2478
+ def run_rewriters (g , funcs , continue_on_error = False , need_sort = True ):
2479
+ if need_sort :
2480
+ topological_sort (g , continue_on_error )
2498
2481
for func in funcs :
2499
2482
ops = func (g , g .get_nodes ())
2500
2483
g .set_nodes (ops )
2501
2484
2485
+ if g .contained_graphs :
2486
+ for dict_val in g .contained_graphs .values ():
2487
+ for attr_name , b_g in dict_val .items ():
2488
+ run_rewriters (b_g , funcs , attr_name )
2489
+
2502
2490
2503
2491
def process_tf_graph (tf_graph , continue_on_error = False , verbose = False , target = None ,
2504
2492
opset = None , custom_op_handlers = None , custom_rewriter = None ,
@@ -2566,7 +2554,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
2566
2554
if custom_rewriter is not None :
2567
2555
rewriters .extend (custom_rewriter )
2568
2556
2569
- run_pre_rewriters (g , rewriters )
2557
+ run_rewriters (g , rewriters , need_sort = False )
2570
2558
2571
2559
# some nodes may already copied into inner Graph, so remove them from main Graph.
2572
2560
g .delete_unused_nodes (output_names )
@@ -2583,7 +2571,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
2583
2571
if TARGET_RS6 in target :
2584
2572
late_rewriters .append (rewrite_incomplete_type_support_rs6 )
2585
2573
if late_rewriters :
2586
- run_late_rewriters (g , late_rewriters , continue_on_error )
2574
+ run_rewriters (g , late_rewriters , continue_on_error , True )
2587
2575
2588
2576
# onnx requires topological sorting
2589
2577
topological_sort (g , continue_on_error )
0 commit comments