@@ -366,15 +366,7 @@ def version_7(cls, ctx, node, **kwargs):
366
366
# consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
367
367
# if maximum_iterations is not const,should add an cast node(cast to int64)
368
368
maximum_iterations_name = node .input [1 ]
369
- cast_mark = False
370
- if node .inputs [1 ].type != "Const" :
371
- cast_mark = True
372
- if cast_mark :
373
- cast_inputs = [maximum_iterations_name ]
374
- attr = {"to" : onnx_pb .TensorProto .INT64 }
375
- cast_name = node .name + "_cast"
376
- cast_node = ctx .make_node ("Cast" , cast_inputs , attr , name = cast_name )
377
- else :
369
+ if node .inputs [1 ].is_const ():
378
370
maximum_iterations = node .inputs [1 ].get_tensor_value ()
379
371
if maximum_iterations == - 1 :
380
372
maximum_iterations = np .iinfo (np .int64 ).max
@@ -386,6 +378,13 @@ def version_7(cls, ctx, node, **kwargs):
386
378
maximum_iterations_name = utils .make_name (node .inputs [1 ].name )
387
379
ctx .make_const (maximum_iterations_name , np .array (maximum_iterations , dtype = np .int64 ))
388
380
ctx .replace_input (node , node .input [1 ], maximum_iterations_name , 1 )
381
+ maximum_iterations_int64 = maximum_iterations_name
382
+ else :
383
+ cast_inputs = [maximum_iterations_name ]
384
+ attr = {"to" : onnx_pb .TensorProto .INT64 }
385
+ cast_name = node .name + "_cast"
386
+ cast_node = ctx .make_node ("Cast" , cast_inputs , attr , name = cast_name )
387
+ maximum_iterations_int64 = cast_node .output [0 ]
389
388
390
389
cond_name = node .get_attr_str ("cond" )
391
390
cond_graph = find_function (cond_name )
@@ -461,16 +460,10 @@ def version_7(cls, ctx, node, **kwargs):
461
460
output_names = output_names [2 :]
462
461
463
462
branches = {"body" : body }
464
- if cast_mark :
465
- loop_node = ctx .make_node ("Loop" , [cast_node .output [0 ], cond_outputs [0 ]] + loop_vars ,
466
- output_count = len (output_shapes ), name = node .name + "_loop" ,
467
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
468
- branches = branches )
469
- else :
470
- loop_node = ctx .make_node ("Loop" , [maximum_iterations_name , cond_outputs [0 ]] + loop_vars ,
471
- output_count = len (output_shapes ), name = node .name + "_loop" ,
472
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
473
- branches = branches )
463
+ loop_node = ctx .make_node ("Loop" , [maximum_iterations_int64 , cond_outputs [0 ]] + loop_vars ,
464
+ output_count = len (output_shapes ), name = node .name + "_loop" ,
465
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
466
+ branches = branches )
474
467
475
468
output_map = dict (zip (output_names , loop_node .output ))
476
469
0 commit comments