Skip to content

Commit eda5dc4

Browse files
author
wayuanho
committed
rewrite the graph and body graph recursively
1 parent 5433ef6 commit eda5dc4

File tree

4 files changed

+67
-23
lines changed

4 files changed

+67
-23
lines changed

tests/test_cond.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def test_simple_cond(self):
3535
output_names_with_port = ["output:0"]
3636
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
3737

38-
@unittest.skip("known issue about onnxruntime that initilizer is subgraph input")
3938
def test_cond_with_const_branch(self):
4039
x_val = np.array([1, 2, 3], dtype=np.float32)
4140
y_val = np.array([4, 5, 6], dtype=np.float32)
@@ -132,8 +131,7 @@ def cond_graph2():
132131
output_names_with_port = ["output:0"]
133132
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
134133

135-
@unittest.skip("not support for now")
136-
def test_cond_with_while_loop(self):
134+
def test_while_loop_between_conds(self):
137135
x_val = np.array([1, 2, 3], dtype=np.float32)
138136
y_val = np.array([4, 5, 6], dtype=np.float32)
139137
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
@@ -156,7 +154,6 @@ def cond_graph():
156154
output_names_with_port = ["output:0"]
157155
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
158156

159-
@unittest.skip("not support for now")
160157
def test_cond_in_while_loop(self):
161158
i = tf.placeholder(tf.int32, (), name="input_1")
162159
inputs = tf.placeholder(tf.float32, (10,), name="input_2")
@@ -170,7 +167,7 @@ def test_cond_in_while_loop(self):
170167
def b(i, out_ta):
171168
new_i = tf.add(i, 1)
172169
x = input_ta.read(i)
173-
x = tf.cond(x >= 0, lambda: x - 1, lambda: x + 3)
170+
x = tf.cond(x > 0, lambda: x - 1, lambda: x + 3)
174171
out_ta_new = out_ta.write(i, x)
175172
return new_i, out_ta_new
176173

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
391391
if outputs is None:
392392
outputs = [name + ":" + str(i) for i in range(output_count)]
393393

394+
output_count = len(outputs)
395+
394396
raw_attr = {}
395397
onnx_attrs = []
396398
for a, v in attr.items():

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,28 +104,63 @@ def run(self):
104104

105105
return self.g.get_nodes()
106106

107+
def _get_output_shape_dtype(self, cond_context):
108+
output_shapes = []
109+
output_dtypes = []
110+
for i, _ in enumerate(cond_context.true_branch_context.output):
111+
true_output = cond_context.true_branch_context.output[i]
112+
false_output = cond_context.false_branch_context.output[i]
113+
true_shape = self.g.get_shape(true_output)
114+
true_dtype = self.g.get_dtype(true_output)
115+
false_shape = self.g.get_shape(false_output)
116+
false_dtype = self.g.get_dtype(false_output)
117+
if true_shape != false_shape:
118+
raise RuntimeError(
119+
"the shape of outputs {} and {} mismatch: {}, {}".format(
120+
true_output,
121+
false_output,
122+
true_shape,
123+
false_shape
124+
)
125+
)
126+
if true_dtype != false_dtype:
127+
raise RuntimeError(
128+
"the shape of outputs {} and {} mismatch: {}, {}".format(
129+
true_output,
130+
false_output,
131+
true_dtype,
132+
false_dtype
133+
)
134+
)
135+
output_shapes.append(true_shape)
136+
output_dtypes.append(true_dtype)
137+
return output_shapes, output_dtypes
138+
107139
def _create_if_node(self, cond_context):
140+
output_shapes, output_dtypes = self._get_output_shape_dtype(cond_context)
108141
if_node = self.g.make_node(
109142
"If",
110143
[cond_context.pred_input],
111144
op_name_scope=cond_context.cond_scope,
112145
outputs=[m.output[0] for m in cond_context.merges],
146+
shapes=output_shapes,
147+
dtypes=output_dtypes,
113148
skip_conversion=False
114149
)
115150
log.debug("set graph for if branchs")
116151
true_graph = utils.construct_graph_from_nodes(
117152
self.g,
118153
list(cond_context.true_branch_context.nodes),
119154
cond_context.true_branch_context.output,
120-
[self.g.get_shape(out) for out in cond_context.true_branch_context.output],
121-
[self.g.get_dtype(out) for out in cond_context.true_branch_context.output]
155+
output_shapes,
156+
output_dtypes
122157
)
123158
false_graph = utils.construct_graph_from_nodes(
124159
self.g,
125160
list(cond_context.false_branch_context.nodes),
126161
cond_context.false_branch_context.output,
127-
[self.g.get_shape(out) for out in cond_context.false_branch_context.output],
128-
[self.g.get_dtype(out) for out in cond_context.false_branch_context.output]
162+
output_shapes,
163+
output_dtypes
129164
)
130165
if_node.set_body_graph_as_attr("then_branch", true_graph)
131166
if_node.set_body_graph_as_attr("else_branch", false_graph)
@@ -268,7 +303,7 @@ def _branch_type(self, branch_output, nodes):
268303
if branch == BranchType.UNKNOWN:
269304
log.debug(
270305
"branch only contains const node: [%s]",
271-
",".join(n for n in nodes)
306+
",".join(n.name for n in nodes)
272307
)
273308
return branch
274309

tf2onnx/tfonnx.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,6 +2419,28 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
24192419
return mapped_op, unmapped_op
24202420

24212421

2422+
def tensorflow_onnx_rewrite(g, rewriters):
2423+
try:
2424+
ops = g.get_nodes()
2425+
for rewrite in rewriters:
2426+
ops = rewrite(g, ops)
2427+
g.set_nodes(ops)
2428+
for node in ops:
2429+
body_graphs = node.get_body_graphs()
2430+
if body_graphs:
2431+
for attr, b_g in body_graphs.items():
2432+
log.debug("start rewriting subgraph of %s's attribute %s", node.name, attr)
2433+
tensorflow_onnx_rewrite(b_g, rewriters)
2434+
except Exception as ex:
2435+
type_, value_, traceback_ = sys.exc_info()
2436+
log.error("node %s: exception %s" % (rewrite, ex))
2437+
ex_ext = traceback.format_exception(type_, value_, traceback_)
2438+
if continue_on_error:
2439+
log.info(ex_ext)
2440+
else:
2441+
raise ex
2442+
2443+
24222444
def transpose_inputs(ctx, inputs_as_nchw):
24232445
"""Insert a transpose from NHWC to NCHW on model input on users request."""
24242446
ops = []
@@ -2553,19 +2575,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25532575
if custom_rewriter is not None:
25542576
rewriters.extend(custom_rewriter)
25552577

2556-
try:
2557-
ops = g.get_nodes()
2558-
for rewrite in rewriters:
2559-
ops = rewrite(g, ops)
2560-
g.set_nodes(ops)
2561-
except Exception as ex:
2562-
type_, value_, traceback_ = sys.exc_info()
2563-
log.error("node %s: exception %s" % (rewrite, ex))
2564-
ex_ext = traceback.format_exception(type_, value_, traceback_)
2565-
if continue_on_error:
2566-
log.info(ex_ext)
2567-
else:
2568-
raise ex
2578+
tensorflow_onnx_rewrite(g, rewriters)
25692579

25702580
# some nodes may already copied into inner Graph, so remove them from main Graph.
25712581
g.delete_unused_nodes(output_names)

0 commit comments

Comments
 (0)