@@ -355,20 +355,29 @@ def version_7(cls, ctx, node, **kwargs):
355
355
# may be removed from output_names below
356
356
output_names = node .output .copy ()
357
357
358
- # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other consumers,
359
- # modify it in place. Otherwise, make a new const node and leave the original unchanged.
358
+ # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
359
+ # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
360
+ # if maximum_iterations is not const,should add an cast node(cast to int64)
360
361
maximum_iterations_name = node .input [1 ]
361
- maximum_iterations = node .inputs [1 ].get_tensor_value ()
362
- if maximum_iterations == - 1 :
363
- maximum_iterations = np .iinfo (np .int64 ).max
364
- consumers = ctx .find_output_consumers (maximum_iterations_name )
365
- external_consumers = [c for c in consumers if c != node and c .type != 'TensorListReserve' ]
366
- if len (external_consumers ) == 0 :
367
- ctx .remove_node (node .inputs [1 ].name )
362
+ if node .inputs [1 ].is_const ():
363
+ maximum_iterations = node .inputs [1 ].get_tensor_value ()
364
+ if maximum_iterations == - 1 :
365
+ maximum_iterations = np .iinfo (np .int64 ).max
366
+ consumers = ctx .find_output_consumers (maximum_iterations_name )
367
+ external_consumers = [c for c in consumers if c != node and c .type != 'TensorListReserve' ]
368
+ if len (external_consumers ) == 0 :
369
+ ctx .remove_node (node .inputs [1 ].name )
370
+ else :
371
+ maximum_iterations_name = utils .make_name (node .inputs [1 ].name )
372
+ ctx .make_const (maximum_iterations_name , np .array (maximum_iterations , dtype = np .int64 ))
373
+ ctx .replace_input (node , node .input [1 ], maximum_iterations_name , 1 )
374
+ maximum_iterations_int64 = maximum_iterations_name
368
375
else :
369
- maximum_iterations_name = utils .make_name (node .inputs [1 ].name )
370
- ctx .make_const (maximum_iterations_name , np .array (maximum_iterations , dtype = np .int64 ))
371
- ctx .replace_input (node , node .input [1 ], maximum_iterations_name , 1 )
376
+ cast_inputs = [maximum_iterations_name ]
377
+ attr = {"to" : onnx_pb .TensorProto .INT64 }
378
+ cast_name = node .name + "_cast"
379
+ cast_node = ctx .make_node ("Cast" , cast_inputs , attr , name = cast_name )
380
+ maximum_iterations_int64 = cast_node .output [0 ]
372
381
373
382
cond_name = node .get_attr_str ("cond" )
374
383
cond_graph = find_function (cond_name )
@@ -444,7 +453,7 @@ def version_7(cls, ctx, node, **kwargs):
444
453
output_names = output_names [2 :]
445
454
446
455
branches = {"body" : body }
447
- loop_node = ctx .make_node ("Loop" , [maximum_iterations_name , cond_outputs [0 ]] + loop_vars ,
456
+ loop_node = ctx .make_node ("Loop" , [maximum_iterations_int64 , cond_outputs [0 ]] + loop_vars ,
448
457
output_count = len (output_shapes ), name = node .name + "_loop" ,
449
458
shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
450
459
branches = branches )
0 commit comments