Skip to content

Commit b027bb2

Browse files
authored
Change Loop op with maximum iterations input M equals to empty string (#1971)
* make Loop op with maximum iterations M equal to empty string to match onnx spec Signed-off-by: Deyu Huang <[email protected]>
1 parent 89c4c5c commit b027bb2

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -381,29 +381,32 @@ def version_7(cls, ctx, node, **kwargs):
381381
# may be removed from output_names below
382382
output_names = node.output.copy()
383383

384-
# Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
384+
# Make maximum_iterations int64. If the const node has no other
385385
# consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
386386
# if maximum_iterations is not const,should add an cast node(cast to int64)
387387
maximum_iterations_name = node.input[1]
388388
if node.inputs[1].is_const():
389389
maximum_iterations = node.inputs[1].get_tensor_value()
390-
if maximum_iterations == -1:
391-
maximum_iterations = np.iinfo(np.int64).max
392-
consumers = ctx.find_output_consumers(maximum_iterations_name)
393-
external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
394-
if len(external_consumers) == 0:
395-
ctx.remove_node(node.inputs[1].name)
390+
# maximum_iterations with -1(tf) means it doesn't set the maximum count.
391+
# For onnx Loop op optional input `M`(int64), represents a maximum trip-count. Set empty string to skip.
392+
if maximum_iterations != -1:
393+
consumers = ctx.find_output_consumers(maximum_iterations_name)
394+
external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
395+
if len(external_consumers) == 0:
396+
ctx.remove_node(node.inputs[1].name)
397+
else:
398+
maximum_iterations_name = utils.make_name(node.inputs[1].name)
399+
ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
400+
ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
401+
maximum_iterations_m = maximum_iterations_name
396402
else:
397-
maximum_iterations_name = utils.make_name(node.inputs[1].name)
398-
ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
399-
ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
400-
maximum_iterations_int64 = maximum_iterations_name
403+
maximum_iterations_m = ""
401404
else:
402405
cast_inputs = [maximum_iterations_name]
403406
attr = {"to": onnx_pb.TensorProto.INT64}
404407
cast_name = node.name + "_cast"
405408
cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
406-
maximum_iterations_int64 = cast_node.output[0]
409+
maximum_iterations_m = cast_node.output[0]
407410

408411
cond_name = node.get_attr_str("cond")
409412
cond_graph = find_function(cond_name)
@@ -427,7 +430,7 @@ def version_7(cls, ctx, node, **kwargs):
427430
cond_input_to_state_var[cond_graph.input_names[idx]] = maximum_iterations_name
428431
continue
429432
if idx < 2:
430-
# skip [0,1] loop_counter, max_iterations
433+
# skip [0,1] loop_counter, max_iterations
431434
continue
432435
n = node.inputs[idx]
433436
if n.type in ["TensorListReserve", "TensorListResize"]:
@@ -511,7 +514,7 @@ def version_7(cls, ctx, node, **kwargs):
511514
output_names = output_names[2:]
512515

513516
branches = {"body": body}
514-
loop_node = ctx.make_node("Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars,
517+
loop_node = ctx.make_node("Loop", [maximum_iterations_m, cond_outputs[0]] + loop_vars,
515518
output_count=len(output_shapes), name=node.name + "_loop",
516519
shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
517520
branches=branches)

0 commit comments

Comments
 (0)