@@ -362,20 +362,30 @@ def version_7(cls, ctx, node, **kwargs):
        # may be removed from output_names below
        output_names = node.output.copy()

-        # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other consumers,
-        # modify it in place. Otherwise, make a new const node and leave the original unchanged.
+        # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
+        # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
+        # If maximum_iterations is not a const, add a Cast node instead (cast to int64).
        maximum_iterations_name = node.input[1]
-        maximum_iterations = node.inputs[1].get_tensor_value()
-        if maximum_iterations == -1:
-            maximum_iterations = np.iinfo(np.int64).max
-        consumers = ctx.find_output_consumers(maximum_iterations_name)
-        external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
-        if len(external_consumers) == 0:
-            ctx.remove_node(node.inputs[1].name)
+        cast_mark = False
+        if node.inputs[1].type != "Const":
+            cast_mark = True
+        if cast_mark:
+            cast_inputs = [maximum_iterations_name]
+            attr = {"to": onnx_pb.TensorProto.INT64}
+            cast_name = node.name + "_cast"
+            cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
        else:
-            maximum_iterations_name = utils.make_name(node.inputs[1].name)
-        ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
-        ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
+            maximum_iterations = node.inputs[1].get_tensor_value()
+            if maximum_iterations == -1:
+                maximum_iterations = np.iinfo(np.int64).max
+            consumers = ctx.find_output_consumers(maximum_iterations_name)
+            external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
+            if len(external_consumers) == 0:
+                ctx.remove_node(node.inputs[1].name)
+            else:
+                maximum_iterations_name = utils.make_name(node.inputs[1].name)
+            ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
+            ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)

        cond_name = node.get_attr_str("cond")
        cond_graph = find_function(cond_name)
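For context on the unchanged const branch above: TensorFlow uses -1 on maximum_iterations to mean "no limit", while ONNX Loop expects a concrete int64 trip count, so -1 is mapped to the largest int64 value. A minimal standalone sketch of that value rewrite, not part of this commit; the helper name normalize_max_iterations is made up for illustration:

import numpy as np

def normalize_max_iterations(value):
    """Map TF's -1 ("no limit") sentinel to the int64 maximum, otherwise cast to int64."""
    value = int(value)
    if value == -1:
        return np.int64(np.iinfo(np.int64).max)
    return np.int64(value)

print(normalize_max_iterations(-1))  # 9223372036854775807
print(normalize_max_iterations(16))  # 16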
@@ -451,10 +461,16 @@ def version_7(cls, ctx, node, **kwargs):
        output_names = output_names[2:]

        branches = {"body": body}
-        loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
-                                  output_count=len(output_shapes), name=node.name + "_loop",
-                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
-                                  branches=branches)
+        if cast_mark:
+            loop_node = ctx.make_node("Loop", [cast_node.output[0], cond_outputs[0]] + loop_vars,
+                                      output_count=len(output_shapes), name=node.name + "_loop",
+                                      shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
+                                      branches=branches)
+        else:
+            loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
+                                      output_count=len(output_shapes), name=node.name + "_loop",
+                                      shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
+                                      branches=branches)

        output_map = dict(zip(output_names, loop_node.output))
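The reason the non-const path needs a Cast rather than an in-place rewrite is that the ONNX Loop operator constrains its trip-count input M to tensor(int64), while a dynamically computed maximum_iterations coming from TensorFlow is typically int32. Below is a minimal, self-contained sketch (not part of this commit, built with onnx.helper rather than tf2onnx's graph context) of the resulting pattern: a Cast to INT64 whose output feeds Loop's first input. Names such as trip_count_i32 and loop_body are made up for illustration.

import onnx
from onnx import TensorProto, helper

# Loop body: (iteration_num, cond_in, acc_in) -> (cond_out, acc_out)
body = helper.make_graph(
    nodes=[
        helper.make_node("Identity", ["cond_in"], ["cond_out"]),
        helper.make_node("Add", ["acc_in", "one"], ["acc_out"]),
    ],
    name="loop_body",
    inputs=[
        helper.make_tensor_value_info("iteration_num", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
        helper.make_tensor_value_info("acc_in", TensorProto.FLOAT, [1]),
    ],
    outputs=[
        helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
        helper.make_tensor_value_info("acc_out", TensorProto.FLOAT, [1]),
    ],
    initializer=[helper.make_tensor("one", TensorProto.FLOAT, [1], [1.0])],
)

# Main graph: the trip count arrives as int32, so it is cast to int64 before it
# is wired into Loop's first input ("M"), the same role the commit's cast_node
# plays for a non-const maximum_iterations.
graph = helper.make_graph(
    nodes=[
        helper.make_node("Cast", ["trip_count_i32"], ["trip_count_i64"],
                         to=TensorProto.INT64),
        helper.make_node("Loop", ["trip_count_i64", "cond", "acc_0"], ["acc_final"],
                         body=body),
    ],
    name="while_as_loop",
    inputs=[
        helper.make_tensor_value_info("trip_count_i32", TensorProto.INT32, []),
        helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
        helper.make_tensor_value_info("acc_0", TensorProto.FLOAT, [1]),
    ],
    outputs=[helper.make_tensor_value_info("acc_final", TensorProto.FLOAT, [1])],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)

An int32 trip count wired directly into Loop would not satisfy the operator schema's type constraint on M, which is why the conversion inserts the Cast whenever the value cannot simply be rewritten as an int64 constant.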