Skip to content

Commit a7fba03

Browse files
author
wayuanho
committed
coordinate pre and late rewrite
1 parent 090434e commit a7fba03

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

tf2onnx/tfonnx.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,19 +2419,6 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
24192419
return mapped_op, unmapped_op
24202420

24212421

2422-
def run_pre_rewriters(g, rewriters):
2423-
ops = g.get_nodes()
2424-
for rewrite in rewriters:
2425-
ops = rewrite(g, ops)
2426-
g.set_nodes(ops)
2427-
for node in ops:
2428-
body_graphs = node.get_body_graphs()
2429-
if body_graphs:
2430-
for attr, b_g in body_graphs.items():
2431-
log.debug("start rewriting subgraph of %s's attribute %s", node.name, attr)
2432-
run_pre_rewriters(b_g, rewriters)
2433-
2434-
24352422
def transpose_inputs(ctx, inputs_as_nchw):
24362423
"""Insert a transpose from NHWC to NCHW on model input on users request."""
24372424
ops = []
@@ -2488,17 +2475,18 @@ def topological_sort(g, continue_on_error):
24882475
pass
24892476

24902477

2491-
def run_late_rewriters(g, funcs, continue_on_error):
2492-
if g.contained_graphs:
2493-
for dict_val in g.contained_graphs.values():
2494-
for attr_name, b_g in dict_val.items():
2495-
run_late_rewriters(b_g, funcs, attr_name)
2496-
2497-
topological_sort(g, continue_on_error)
2478+
def run_rewriters(g, funcs, continue_on_error=False, need_sort=True):
2479+
if need_sort:
2480+
topological_sort(g, continue_on_error)
24982481
for func in funcs:
24992482
ops = func(g, g.get_nodes())
25002483
g.set_nodes(ops)
25012484

2485+
if g.contained_graphs:
2486+
for dict_val in g.contained_graphs.values():
2487+
for attr_name, b_g in dict_val.items():
2488+
run_rewriters(b_g, funcs, attr_name)
2489+
25022490

25032491
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
25042492
opset=None, custom_op_handlers=None, custom_rewriter=None,
@@ -2566,7 +2554,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25662554
if custom_rewriter is not None:
25672555
rewriters.extend(custom_rewriter)
25682556

2569-
run_pre_rewriters(g, rewriters)
2557+
run_rewriters(g, rewriters, need_sort=False)
25702558

25712559
# some nodes may already copied into inner Graph, so remove them from main Graph.
25722560
g.delete_unused_nodes(output_names)
@@ -2583,7 +2571,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25832571
if TARGET_RS6 in target:
25842572
late_rewriters.append(rewrite_incomplete_type_support_rs6)
25852573
if late_rewriters:
2586-
run_late_rewriters(g, late_rewriters, continue_on_error)
2574+
run_rewriters(g, late_rewriters, continue_on_error, True)
25872575

25882576
# onnx requires topological sorting
25892577
topological_sort(g, continue_on_error)

0 commit comments

Comments
 (0)