@@ -381,29 +381,32 @@ def version_7(cls, ctx, node, **kwargs):
381
381
# may be removed from output_names below
382
382
output_names = node .output .copy ()
383
383
384
- # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx) . If the const node has no other
384
+ # Make maximum_iterations int64. If the const node has no other
385
385
# consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
386
386
# if maximum_iterations is not const, should add a cast node (cast to int64)
387
387
maximum_iterations_name = node .input [1 ]
388
388
if node .inputs [1 ].is_const ():
389
389
maximum_iterations = node .inputs [1 ].get_tensor_value ()
390
- if maximum_iterations == - 1 :
391
- maximum_iterations = np .iinfo (np .int64 ).max
392
- consumers = ctx .find_output_consumers (maximum_iterations_name )
393
- external_consumers = [c for c in consumers if c != node and c .type != 'TensorListReserve' ]
394
- if len (external_consumers ) == 0 :
395
- ctx .remove_node (node .inputs [1 ].name )
390
+ # A maximum_iterations of -1 (TF convention) means no maximum iteration count is set.
391
+ # For the ONNX Loop op, the optional input `M` (int64) represents a maximum trip-count; pass an empty string to omit it.
392
+ if maximum_iterations != - 1 :
393
+ consumers = ctx .find_output_consumers (maximum_iterations_name )
394
+ external_consumers = [c for c in consumers if c != node and c .type != 'TensorListReserve' ]
395
+ if len (external_consumers ) == 0 :
396
+ ctx .remove_node (node .inputs [1 ].name )
397
+ else :
398
+ maximum_iterations_name = utils .make_name (node .inputs [1 ].name )
399
+ ctx .make_const (maximum_iterations_name , np .array (maximum_iterations , dtype = np .int64 ))
400
+ ctx .replace_input (node , node .input [1 ], maximum_iterations_name , 1 )
401
+ maximum_iterations_m = maximum_iterations_name
396
402
else :
397
- maximum_iterations_name = utils .make_name (node .inputs [1 ].name )
398
- ctx .make_const (maximum_iterations_name , np .array (maximum_iterations , dtype = np .int64 ))
399
- ctx .replace_input (node , node .input [1 ], maximum_iterations_name , 1 )
400
- maximum_iterations_int64 = maximum_iterations_name
403
+ maximum_iterations_m = ""
401
404
else :
402
405
cast_inputs = [maximum_iterations_name ]
403
406
attr = {"to" : onnx_pb .TensorProto .INT64 }
404
407
cast_name = node .name + "_cast"
405
408
cast_node = ctx .make_node ("Cast" , cast_inputs , attr , name = cast_name )
406
- maximum_iterations_int64 = cast_node .output [0 ]
409
+ maximum_iterations_m = cast_node .output [0 ]
407
410
408
411
cond_name = node .get_attr_str ("cond" )
409
412
cond_graph = find_function (cond_name )
@@ -427,7 +430,7 @@ def version_7(cls, ctx, node, **kwargs):
427
430
cond_input_to_state_var [cond_graph .input_names [idx ]] = maximum_iterations_name
428
431
continue
429
432
if idx < 2 :
430
- # skip [0,1] loop_counter, max_iterations
433
+ # skip [0,1] loop_counter, max_iterations
431
434
continue
432
435
n = node .inputs [idx ]
433
436
if n .type in ["TensorListReserve" , "TensorListResize" ]:
@@ -511,7 +514,7 @@ def version_7(cls, ctx, node, **kwargs):
511
514
output_names = output_names [2 :]
512
515
513
516
branches = {"body" : body }
514
- loop_node = ctx .make_node ("Loop" , [maximum_iterations_int64 , cond_outputs [0 ]] + loop_vars ,
517
+ loop_node = ctx .make_node ("Loop" , [maximum_iterations_m , cond_outputs [0 ]] + loop_vars ,
515
518
output_count = len (output_shapes ), name = node .name + "_loop" ,
516
519
shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
517
520
branches = branches )
0 commit comments