Skip to content

Commit 4b81055

Browse files
Merge pull request #1240 from wangqiaoshi/support_dynamic_max_iterations_in_while
add support dynamic max_iterations in loop
2 parents e36dac2 + b38a89c commit 4b81055

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -355,20 +355,29 @@ def version_7(cls, ctx, node, **kwargs):
355355
# may be removed from output_names below
356356
output_names = node.output.copy()
357357

358-
# Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other consumers,
359-
# modify it in place. Otherwise, make a new const node and leave the original unchanged.
358+
# Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
359+
# consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
360+
# if maximum_iterations is not const,should add an cast node(cast to int64)
360361
maximum_iterations_name = node.input[1]
361-
maximum_iterations = node.inputs[1].get_tensor_value()
362-
if maximum_iterations == -1:
363-
maximum_iterations = np.iinfo(np.int64).max
364-
consumers = ctx.find_output_consumers(maximum_iterations_name)
365-
external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
366-
if len(external_consumers) == 0:
367-
ctx.remove_node(node.inputs[1].name)
362+
if node.inputs[1].is_const():
363+
maximum_iterations = node.inputs[1].get_tensor_value()
364+
if maximum_iterations == -1:
365+
maximum_iterations = np.iinfo(np.int64).max
366+
consumers = ctx.find_output_consumers(maximum_iterations_name)
367+
external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
368+
if len(external_consumers) == 0:
369+
ctx.remove_node(node.inputs[1].name)
370+
else:
371+
maximum_iterations_name = utils.make_name(node.inputs[1].name)
372+
ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
373+
ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
374+
maximum_iterations_int64 = maximum_iterations_name
368375
else:
369-
maximum_iterations_name = utils.make_name(node.inputs[1].name)
370-
ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
371-
ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
376+
cast_inputs = [maximum_iterations_name]
377+
attr = {"to": onnx_pb.TensorProto.INT64}
378+
cast_name = node.name + "_cast"
379+
cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
380+
maximum_iterations_int64 = cast_node.output[0]
372381

373382
cond_name = node.get_attr_str("cond")
374383
cond_graph = find_function(cond_name)
@@ -444,7 +453,7 @@ def version_7(cls, ctx, node, **kwargs):
444453
output_names = output_names[2:]
445454

446455
branches = {"body": body}
447-
loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
456+
loop_node = ctx.make_node("Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars,
448457
output_count=len(output_shapes), name=node.name + "_loop",
449458
shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
450459
branches=branches)

0 commit comments

Comments
 (0)