Skip to content

Commit ddb4301

Browse files
authored
Merge pull request #581 from lucienwang1009/init_cond_loop
calculate condition initialization for Loop
2 parents d239d1c + 4c4eff2 commit ddb4301

File tree

3 files changed

+80
-13
lines changed

3 files changed

+80
-13
lines changed

tests/test_loops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,33 @@ def b(i, out_ta):
160160
output_names_with_port = ["i:0", "output_ta:0"]
161161
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
162162

163+
def test_while_loop_with_cond_init_false(self):
164+
i = tf.placeholder(tf.int32, (), name="input_1")
165+
inputs = tf.placeholder(tf.float32, (10,), name="input_2")
166+
167+
inputs_2 = tf.identity(inputs)
168+
input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs_2)
169+
output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
170+
171+
c = lambda i, *_: tf.logical_and(i < 10, i >= 0)
172+
173+
def b(i, out_ta):
174+
new_i = tf.add(i, 1)
175+
x = input_ta.read(i)
176+
y = x + 3
177+
out_ta_new = out_ta.write(i, y)
178+
return new_i, out_ta_new
179+
180+
i_final, out_final = tf.while_loop(c, b, [i, output_ta])
181+
_ = tf.identity(i_final, name="i")
182+
_ = tf.identity(out_final.stack(), name="output_ta")
183+
input_names_with_port = ["input_1:0", "input_2:0"]
184+
feed_dict = {"input_1:0": np.array(20, dtype=np.int32),
185+
"input_2:0": np.array([2.0, 16.0, 5.0, 1.6, 5.0, 6.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32)}
186+
187+
output_names_with_port = ["i:0", "output_ta:0"]
188+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
189+
163190
def test_map_fn(self):
164191
def fn0(elem):
165192
res = elem + elem * elem

tf2onnx/rewriter/loop_rewriter.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,19 @@ def rewrite(self, context):
4545
cell_g_info = context.cell_graph
4646
cond_g_info = context.cond_graph
4747

48-
# todo(pengwa): we don't check the case where loop body won't be executed at all.
48+
# create a dummy loop to calculate the init condition
49+
init_cond_output = self._create_subgraph_initial_cond(cond_g_info)
4950

5051
## create Loop body graph with existing nodes
5152

52-
# replace condition graph's inputs to be cell graph's outputs, because we want condition graph
53-
# to consume cell graph outputs.
54-
for loop_var in cond_g_info.dependent_vars:
55-
self.g.replace_all_inputs(cond_g_info.nodes, loop_var.switch_true_identity_output.id,
56-
loop_var.next_iteration_input.id)
57-
5853
body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
5954
body_outputs = cond_g_info.outputs + cell_g_info.outputs
6055
for out_tensor_value_info in body_outputs:
6156
shape = out_tensor_value_info.shape
62-
utils.make_sure(shape is not None, "Shape of {} is None".format(out_tensor_value_info.id))
57+
utils.make_sure(
58+
shape is not None,
59+
"Conversion of Loop requires output shape [{}] exists".format(out_tensor_value_info.id)
60+
)
6361
out_tensor_value_info.shape = utils.create_vague_shape_like(shape)
6462

6563
loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)
@@ -90,7 +88,7 @@ def rewrite(self, context):
9088
loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id, data_node.output[0])
9189

9290
## create Loop node
93-
loop_node = self._create_loop_node(context, loop_props)
91+
loop_node = self._create_loop_node(context, loop_props, init_cond_output)
9492
if not loop_node:
9593
logger.error("failed to create loop node during rewrite")
9694
return REWRITER_RESULT.FAIL
@@ -104,7 +102,48 @@ def rewrite(self, context):
104102
logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
105103
return REWRITER_RESULT.FAIL
106104

107-
def _create_loop_node(self, context, loop_props):
105+
def _create_subgraph_initial_cond(self, cond_graph):
106+
"""Create subgraph to calculate initial cond."""
107+
# copy condition subgraph to parent graph
108+
copied_nodes = []
109+
name_scope = utils.make_name("copy")
110+
for node in cond_graph.nodes:
111+
new_name = "{}/{}".format(name_scope, node.name)
112+
new_outputs = ["{}/{}".format(name_scope, out) for out in node.output]
113+
# some inputs are out of cond_graph.nodes, keep them intact
114+
new_inputs = []
115+
for inp in node.input:
116+
if self.g.get_node_by_output(inp) in cond_graph.nodes:
117+
new_inputs.append("{}/{}".format(name_scope, inp))
118+
else:
119+
new_inputs.append(inp)
120+
121+
new_node = self.g.make_node(
122+
node.type, new_inputs, outputs=new_outputs,
123+
attr=node.attr, name=new_name,
124+
shapes=node.output_shapes, dtypes=node.output_dtypes,
125+
skip_conversion=node.skip_conversion, infer_shape_dtype=False
126+
)
127+
body_graphs = node.graph.contained_graphs.pop(node.name, None)
128+
if body_graphs:
129+
for attr_name, body_graph in body_graphs.items():
130+
body_graph.parent_graph = g
131+
new_node.set_body_graph_as_attr(attr_name, body_graph)
132+
copied_nodes.append(new_node)
133+
134+
# replace all inputs of condition graph by initializer (enter_input)
135+
for loop_var in cond_graph.dependent_vars:
136+
self.g.replace_all_inputs(
137+
copied_nodes,
138+
loop_var.next_iteration_input.id,
139+
loop_var.enter_input_id
140+
)
141+
init_cond_output = "{}/{}".format(name_scope, cond_graph.outputs[0].id)
142+
self.g.set_dtype(init_cond_output, cond_graph.outputs[0].dtype)
143+
self.g.set_shape(init_cond_output, cond_graph.outputs[0].shape)
144+
return init_cond_output
145+
146+
def _create_loop_node(self, context, loop_props, init_cond_output):
108147
loop_outputs = []
109148
loop_output_shapes = []
110149
loop_output_dtypes = []
@@ -123,8 +162,7 @@ def _create_loop_node(self, context, loop_props):
123162
# trip count and cond are not used, giving them values just because of a bug
124163
# (https://github.com/Microsoft/onnxruntime/issues/255) in onnxruntime.
125164
trip_cnt = self.g.make_const(utils.make_name("trip_count"), np.array(sys.maxsize, dtype=np.int64))
126-
cond = self.g.make_const(utils.make_name("cond"), np.array(True, dtype=np.bool))
127-
loop_node = self.g.make_node("Loop", [trip_cnt.output[0]] + [cond.output[0]] +
165+
loop_node = self.g.make_node("Loop", [trip_cnt.output[0]] + [init_cond_output] +
128166
loop_props.state_inputs_initial_values, # ONNX Loop support state inputs only
129167
outputs=loop_outputs, op_name_scope="generic_loop",
130168
shapes=loop_output_shapes, dtypes=loop_output_dtypes,

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,11 @@ def _crop_loop_condition_sub_graph(self, context):
295295
loop_var = context.loop_properties.all_variables[enter_node.name]
296296

297297
# cut off connection between condition graph and Merge node.
298+
# replace condition graph's inputs to be cell graph's outputs, because we want condition graph
299+
# to consume cell graph outputs.
298300
non_switch_consumers = [n for n in self.g.find_output_consumers(merge_node.output[0]) if n.type != "Switch"]
299301
self.g.replace_all_inputs(non_switch_consumers, merge_node.output[0],
300-
loop_var.switch_true_identity_output.id)
302+
loop_var.next_iteration_input.id)
301303
dependent_vars.append(loop_var)
302304

303305
# cut off connection between condition graph and LoopCond node.

0 commit comments

Comments
 (0)