Skip to content

Commit 48ca297

Browse files
authored
Merge pull request #337 from lucienwang1009/cond_inside_loop
rewrite the graph and body graph recursively
2 parents 271d260 + 9f0e385 commit 48ca297

File tree

4 files changed

+65
-31
lines changed

4 files changed

+65
-31
lines changed

tests/test_cond.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def test_simple_cond(self):
3535
output_names_with_port = ["output:0"]
3636
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
3737

38-
@unittest.skip("known issue about onnxruntime that initilizer is subgraph input")
3938
def test_cond_with_const_branch(self):
4039
x_val = np.array([1, 2, 3], dtype=np.float32)
4140
y_val = np.array([4, 5, 6], dtype=np.float32)
@@ -132,8 +131,7 @@ def cond_graph2():
132131
output_names_with_port = ["output:0"]
133132
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
134133

135-
@unittest.skip("not support for now")
136-
def test_cond_with_while_loop(self):
134+
def test_while_loop_between_conds(self):
137135
x_val = np.array([1, 2, 3], dtype=np.float32)
138136
y_val = np.array([4, 5, 6], dtype=np.float32)
139137
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
@@ -156,7 +154,6 @@ def cond_graph():
156154
output_names_with_port = ["output:0"]
157155
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
158156

159-
@unittest.skip("not support for now")
160157
def test_cond_in_while_loop(self):
161158
i = tf.placeholder(tf.int32, (), name="input_1")
162159
inputs = tf.placeholder(tf.float32, (10,), name="input_2")
@@ -170,7 +167,7 @@ def test_cond_in_while_loop(self):
170167
def b(i, out_ta):
171168
new_i = tf.add(i, 1)
172169
x = input_ta.read(i)
173-
x = tf.cond(x >= 0, lambda: x - 1, lambda: x + 3)
170+
x = tf.cond(x > 0, lambda: x - 1, lambda: x + 3)
174171
out_ta_new = out_ta.write(i, x)
175172
return new_i, out_ta_new
176173

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
390390
if outputs is None:
391391
outputs = [name + ":" + str(i) for i in range(output_count)]
392392

393+
output_count = len(outputs)
394+
393395
raw_attr = {}
394396
onnx_attrs = []
395397
for a, v in attr.items():

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,28 +104,63 @@ def run(self):
104104

105105
return self.g.get_nodes()
106106

107+
def _get_output_shape_dtype(self, cond_context):
108+
output_shapes = []
109+
output_dtypes = []
110+
for i, _ in enumerate(cond_context.true_branch_context.output):
111+
true_output = cond_context.true_branch_context.output[i]
112+
false_output = cond_context.false_branch_context.output[i]
113+
true_shape = self.g.get_shape(true_output)
114+
true_dtype = self.g.get_dtype(true_output)
115+
false_shape = self.g.get_shape(false_output)
116+
false_dtype = self.g.get_dtype(false_output)
117+
if true_shape != false_shape:
118+
raise RuntimeError(
119+
"the shape of outputs {} and {} mismatch: {}, {}".format(
120+
true_output,
121+
false_output,
122+
true_shape,
123+
false_shape
124+
)
125+
)
126+
if true_dtype != false_dtype:
127+
raise RuntimeError(
128+
"the shape of outputs {} and {} mismatch: {}, {}".format(
129+
true_output,
130+
false_output,
131+
true_dtype,
132+
false_dtype
133+
)
134+
)
135+
output_shapes.append(true_shape)
136+
output_dtypes.append(true_dtype)
137+
return output_shapes, output_dtypes
138+
107139
def _create_if_node(self, cond_context):
140+
output_shapes, output_dtypes = self._get_output_shape_dtype(cond_context)
108141
if_node = self.g.make_node(
109142
"If",
110143
[cond_context.pred_input],
111144
op_name_scope=cond_context.cond_scope,
112145
outputs=[m.output[0] for m in cond_context.merges],
146+
shapes=output_shapes,
147+
dtypes=output_dtypes,
113148
skip_conversion=False
114149
)
115150
log.debug("set graph for if branchs")
116151
true_graph = utils.construct_graph_from_nodes(
117152
self.g,
118153
list(cond_context.true_branch_context.nodes),
119154
cond_context.true_branch_context.output,
120-
[self.g.get_shape(out) for out in cond_context.true_branch_context.output],
121-
[self.g.get_dtype(out) for out in cond_context.true_branch_context.output]
155+
output_shapes,
156+
output_dtypes
122157
)
123158
false_graph = utils.construct_graph_from_nodes(
124159
self.g,
125160
list(cond_context.false_branch_context.nodes),
126161
cond_context.false_branch_context.output,
127-
[self.g.get_shape(out) for out in cond_context.false_branch_context.output],
128-
[self.g.get_dtype(out) for out in cond_context.false_branch_context.output]
162+
output_shapes,
163+
output_dtypes
129164
)
130165
if_node.set_body_graph_as_attr("then_branch", true_graph)
131166
if_node.set_body_graph_as_attr("else_branch", false_graph)
@@ -268,7 +303,7 @@ def _branch_type(self, branch_output, nodes):
268303
if branch == BranchType.UNKNOWN:
269304
log.debug(
270305
"branch only contains const node: [%s]",
271-
",".join(n for n in nodes)
306+
",".join(n.name for n in nodes)
272307
)
273308
return branch
274309

tf2onnx/tfonnx.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,16 +2482,28 @@ def topological_sort(g, continue_on_error):
24822482
pass
24832483

24842484

2485-
def run_late_rewriters(g, funcs, continue_on_error):
2485+
def run_rewriters(g, funcs, continue_on_error):
2486+
"""Rewrite the original graph and body graphs of nodes"""
2487+
# NOTE(wayuanho):
2488+
# 1. we don't sort graph here, rewriter is expected to do it on its own.
2489+
# 2. the graph here may have circles, current topological_sort cannot handle it.
2490+
for func in funcs:
2491+
try:
2492+
ops = func(g, g.get_nodes())
2493+
g.set_nodes(ops)
2494+
except Exception as ex:
2495+
type_, value_, traceback_ = sys.exc_info()
2496+
log.error("rewriter %s: exception %s", func, ex)
2497+
ex_ext = traceback.format_exception(type_, value_, traceback_)
2498+
if continue_on_error:
2499+
log.info(ex_ext)
2500+
else:
2501+
raise ex
2502+
24862503
if g.contained_graphs:
24872504
for dict_val in g.contained_graphs.values():
24882505
for attr_name, b_g in dict_val.items():
2489-
run_late_rewriters(b_g, funcs, attr_name)
2490-
2491-
topological_sort(g, continue_on_error)
2492-
for func in funcs:
2493-
ops = func(g, g.get_nodes())
2494-
g.set_nodes(ops)
2506+
run_rewriters(b_g, funcs, attr_name)
24952507

24962508

24972509
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
@@ -2561,19 +2573,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25612573
if custom_rewriter is not None:
25622574
rewriters.extend(custom_rewriter)
25632575

2564-
try:
2565-
ops = g.get_nodes()
2566-
for rewrite in rewriters:
2567-
ops = rewrite(g, ops)
2568-
g.set_nodes(ops)
2569-
except Exception as ex:
2570-
type_, value_, traceback_ = sys.exc_info()
2571-
log.error("node %s: exception %s" % (rewrite, ex))
2572-
ex_ext = traceback.format_exception(type_, value_, traceback_)
2573-
if continue_on_error:
2574-
log.info(ex_ext)
2575-
else:
2576-
raise ex
2576+
run_rewriters(g, rewriters, continue_on_error)
25772577

25782578
# some nodes may already copied into inner Graph, so remove them from main Graph.
25792579
g.delete_unused_nodes(output_names)
@@ -2590,7 +2590,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25902590
if TARGET_RS6 in target:
25912591
late_rewriters.append(rewrite_incomplete_type_support_rs6)
25922592
if late_rewriters:
2593-
run_late_rewriters(g, late_rewriters, continue_on_error)
2593+
run_rewriters(g, late_rewriters, continue_on_error)
25942594

25952595
# onnx requires topological sorting
25962596
topological_sort(g, continue_on_error)

0 commit comments

Comments
 (0)