Skip to content

Commit b38a89c

Browse files
committed
add support dynamic max_iterations in loop v2
Signed-off-by: wangqiaoshi <[email protected]>
1 parent 9f922a9 commit b38a89c

File tree

1 file changed

+12
-19
lines changed

1 file changed

+12
-19
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -366,15 +366,7 @@ def version_7(cls, ctx, node, **kwargs):
366366
# consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
367367
# if maximum_iterations is not const,should add an cast node(cast to int64)
368368
maximum_iterations_name = node.input[1]
369-
cast_mark = False
370-
if node.inputs[1].type != "Const":
371-
cast_mark = True
372-
if cast_mark:
373-
cast_inputs = [maximum_iterations_name]
374-
attr = {"to": onnx_pb.TensorProto.INT64}
375-
cast_name = node.name + "_cast"
376-
cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
377-
else:
369+
if node.inputs[1].is_const():
378370
maximum_iterations = node.inputs[1].get_tensor_value()
379371
if maximum_iterations == -1:
380372
maximum_iterations = np.iinfo(np.int64).max
@@ -386,6 +378,13 @@ def version_7(cls, ctx, node, **kwargs):
386378
maximum_iterations_name = utils.make_name(node.inputs[1].name)
387379
ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
388380
ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
381+
maximum_iterations_int64 = maximum_iterations_name
382+
else:
383+
cast_inputs = [maximum_iterations_name]
384+
attr = {"to": onnx_pb.TensorProto.INT64}
385+
cast_name = node.name + "_cast"
386+
cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
387+
maximum_iterations_int64 = cast_node.output[0]
389388

390389
cond_name = node.get_attr_str("cond")
391390
cond_graph = find_function(cond_name)
@@ -461,16 +460,10 @@ def version_7(cls, ctx, node, **kwargs):
461460
output_names = output_names[2:]
462461

463462
branches = {"body": body}
464-
if cast_mark:
465-
loop_node = ctx.make_node("Loop", [cast_node.output[0], cond_outputs[0]] + loop_vars,
466-
output_count=len(output_shapes), name=node.name + "_loop",
467-
shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
468-
branches=branches)
469-
else:
470-
loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
471-
output_count=len(output_shapes), name=node.name + "_loop",
472-
shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
473-
branches=branches)
463+
loop_node = ctx.make_node("Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars,
464+
output_count=len(output_shapes), name=node.name + "_loop",
465+
shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
466+
branches=branches)
474467

475468
output_map = dict(zip(output_names, loop_node.output))
476469

0 commit comments

Comments
 (0)