Skip to content

Commit ac2f675

Browse files
Merge pull request #1238 from onnx/tom/multiscanout
Add support for multiple scan outputs
2 parents 06c1a32 + 0013b3a commit ac2f675

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

tests/test_loops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,40 @@ def b(i, out_ta):
188188
output_names_with_port = ["i:0", "output_ta:0"]
189189
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
190190

191+
def test_while_loop_with_multi_scan_outputs(self):
192+
def func(i, inputs1, inputs2):
193+
inputs1_ = tf.identity(inputs1)
194+
inputs2_ = tf.identity(inputs2)
195+
input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs1_)
196+
input_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs2_)
197+
output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
198+
output_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
199+
200+
c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
201+
202+
def b(i, out_ta, out_ta2):
203+
new_i = tf.add(i, 1)
204+
x = input_ta.read(i)
205+
y = input_ta2.read(i)
206+
z = x + 3 + y
207+
p = x * y * 2
208+
out_ta_new = out_ta.write(i, z)
209+
out_ta_new2 = out_ta2.write(i, p)
210+
return new_i, out_ta_new, out_ta_new2
211+
212+
i_final, out_final, out_final2 = tf.while_loop(c, b, [i, output_ta, output_ta2])
213+
i_final_ = tf.identity(i_final, name="i")
214+
out_final_ = tf.identity(out_final.stack(), name="output_ta")
215+
out_final2_ = tf.identity(out_final2.stack(), name="output_ta2")
216+
return i_final_, out_final_, out_final2_
217+
218+
input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
219+
feed_dict = {"input_1:0": np.array(0, dtype=np.int32),
220+
"input_2:0": np.array([2.0, 16.0, 5.0, 1.6, 5.0, 6.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32),
221+
"input_3:0": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 16.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32)}
222+
output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"]
223+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
224+
191225
@check_onnxruntime_min_version(
192226
"0.5.0",
193227
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"

tf2onnx/onnx_opset/controlflow.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def version_7(cls, ctx, node, **kwargs):
421421
del output_names[idx]
422422
del body.outputs[idx]
423423

424-
removed_scan_outputs = {}
424+
scan_output_names = []
425425
# remove tensor array that are passed in to the loop
426426
for idx, n in reversed(to_remove):
427427
ctx.remove_node(n.name)
@@ -430,19 +430,15 @@ def version_7(cls, ctx, node, **kwargs):
430430
del body.func_inputs[idx]
431431
del cond_graph.func_inputs[idx]
432432
del tf_while_inputs[idx]
433-
# save the index of the scan output
434-
removed_scan_outputs[body.outputs[idx]] = idx
433+
scan_output_names.append(body.outputs[idx])
435434
del body.outputs[idx]
436-
# FIXME: Output shapes may be in wrong order if there are multiple scan outputs
437435
output_shapes.append(output_shapes[idx])
438436
output_dtypes.append(output_dtypes[idx])
439437
output_names.append(output_names[idx])
440438
del output_shapes[idx]
441439
del output_dtypes[idx]
442440
del output_names[idx]
443441

444-
utils.make_sure(len(removed_scan_outputs) <= 1, "converter only supports while loops with a single scan output")
445-
446442
ctx.remove_node(node.name)
447443

448444
# In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
@@ -467,7 +463,7 @@ def version_7(cls, ctx, node, **kwargs):
467463
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()
468464

469465
wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
470-
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)
466+
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)
471467

472468
# if there was a tensorflow variant type, bind in a real type here
473469
# FIXME: I don't think this is needed anymore
@@ -477,7 +473,7 @@ def version_7(cls, ctx, node, **kwargs):
477473

478474

479475
def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
480-
output_dtypes, scope, parent, cond_graph, tf_while_inputs, removed_scan_outputs):
476+
output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
481477
"""Wire subgraph graph into main."""
482478
remove_parents = []
483479
to_remove = []
@@ -521,9 +517,10 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
521517
g.replace_inputs(node, [node.input[2]])
522518
scan_outputs.append(node.output[0])
523519

524-
if len(scan_outputs) != len(removed_scan_outputs):
520+
if len(scan_outputs) != len(scan_output_names):
525521
raise ValueError("While loop couldn't find scan output index for nodes")
526522

523+
names_to_scan_outputs = {}
527524
for output in scan_outputs:
528525
last_output = output
529526
consumers = g.find_output_consumers(last_output)
@@ -533,10 +530,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
533530
raise ValueError("While loop couldn't find scan output index for node " + node.name)
534531
last_output = node.output[0]
535532
consumers = g.find_output_consumers(last_output)
536-
if last_output not in removed_scan_outputs:
533+
if last_output not in scan_output_names:
537534
raise ValueError("While loop couldn't find scan output index for node " + node.name)
538-
# TODO: store index to ensure scan outputs are in correct order for multiple outputs
539-
# initial_output_index = removed_scan_outputs[last_output]
535+
names_to_scan_outputs[last_output] = output
536+
537+
# Reorder scan outputs
538+
scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]
540539

541540
# remove all nodes feeding to TensorListSetItem's reserved tensor
542541
while remove_parents:

0 commit comments

Comments
 (0)