
Commit 9f922a9

11112968wangqiaoshi authored and committed

add support for dynamic max_iterations in loop
Signed-off-by: wangqiaoshi <[email protected]>
1 parent ac2f675 commit 9f922a9

File tree

1 file changed: +32 -16 lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 32 additions & 16 deletions
@@ -362,20 +362,30 @@ def version_7(cls, ctx, node, **kwargs):
         # may be removed from output_names below
         output_names = node.output.copy()

-        # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other consumers,
-        # modify it in place. Otherwise, make a new const node and leave the original unchanged.
+        # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
+        # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
+        # if maximum_iterations is not const, add a Cast node (cast to int64)
         maximum_iterations_name = node.input[1]
-        maximum_iterations = node.inputs[1].get_tensor_value()
-        if maximum_iterations == -1:
-            maximum_iterations = np.iinfo(np.int64).max
-        consumers = ctx.find_output_consumers(maximum_iterations_name)
-        external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
-        if len(external_consumers) == 0:
-            ctx.remove_node(node.inputs[1].name)
+        cast_mark = False
+        if node.inputs[1].type != "Const":
+            cast_mark = True
+        if cast_mark:
+            cast_inputs = [maximum_iterations_name]
+            attr = {"to": onnx_pb.TensorProto.INT64}
+            cast_name = node.name + "_cast"
+            cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
         else:
-            maximum_iterations_name = utils.make_name(node.inputs[1].name)
-            ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
-            ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
+            maximum_iterations = node.inputs[1].get_tensor_value()
+            if maximum_iterations == -1:
+                maximum_iterations = np.iinfo(np.int64).max
+            consumers = ctx.find_output_consumers(maximum_iterations_name)
+            external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
+            if len(external_consumers) == 0:
+                ctx.remove_node(node.inputs[1].name)
+            else:
+                maximum_iterations_name = utils.make_name(node.inputs[1].name)
+                ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
+                ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)

         cond_name = node.get_attr_str("cond")
         cond_graph = find_function(cond_name)
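
Why the Cast: ONNX Loop takes its trip count M as a scalar int64 tensor, so a constant maximum_iterations can simply be rewritten to int64 in place, but a value that is only produced at runtime cannot be read with get_tensor_value(). The new branch therefore routes it through a Cast node. Below is a minimal standalone sketch (not part of this commit; the tensor names are hypothetical) of the equivalent node, written with onnx.helper:

from onnx import helper, TensorProto

# Cast a runtime-provided trip count to int64 so it can feed Loop's "M" input.
cast = helper.make_node(
    "Cast",
    inputs=["max_iters"],          # hypothetical dynamic trip-count tensor
    outputs=["max_iters_int64"],
    to=TensorProto.INT64,
    name="while_node_cast",
)
# The Loop node then consumes "max_iters_int64" as its first input (M).
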
@@ -451,10 +461,16 @@ def version_7(cls, ctx, node, **kwargs):
         output_names = output_names[2:]

         branches = {"body": body}
-        loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
-                                  output_count=len(output_shapes), name=node.name + "_loop",
-                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
-                                  branches=branches)
+        if cast_mark:
+            loop_node = ctx.make_node("Loop", [cast_node.output[0], cond_outputs[0]] + loop_vars,
+                                      output_count=len(output_shapes), name=node.name + "_loop",
+                                      shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
+                                      branches=branches)
+        else:
+            loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
+                                      output_count=len(output_shapes), name=node.name + "_loop",
+                                      shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
+                                      branches=branches)

         output_map = dict(zip(output_names, loop_node.output))
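
For context, a minimal TensorFlow sketch (hypothetical, not part of this commit) of the situation the new cast_mark branch targets: the loop bound arrives as a graph input, so the While node's maximum_iterations input is not a Const node.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([], tf.int32)])
def bounded_count(max_iters):
    # The bound is a graph input, i.e. dynamic, rather than a constant folded into the graph.
    i = tf.constant(0)
    result = tf.while_loop(
        cond=lambda i: i < 1000,
        body=lambda i: [i + 1],
        loop_vars=[i],
        maximum_iterations=max_iters,
    )
    return result[0]

Converting a function like this (e.g. with tf2onnx.convert.from_function) should now succeed, with the dynamic bound cast to int64 and wired into the ONNX Loop's trip-count input.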
