@@ -318,9 +318,10 @@ def build_dynamic_target_size(ctx, transposed_intput, target_hw):
318
318
shape_of_transposed_input = ctx .make_node ("Shape" , [transposed_intput ])
319
319
first_half_of_shape = GraphBuilder (ctx ).make_slice (
320
320
{"data" : shape_of_transposed_input .output [0 ], "ends" : [2 ], "starts" : [0 ]})
321
- target_size_int64 = ctx .make_node ("Cast" , [target_hw ], attr = {'to' : TensorProto .INT64 })
321
+ if ctx .get_dtype (target_hw ) != TensorProto .INT64 :
322
+ target_hw = ctx .make_node ("Cast" , [target_hw ], attr = {'to' : TensorProto .INT64 }).output [0 ]
322
323
# We build a tensor containing [n c nh nw]
323
- final_target_size = ctx .make_node ("Concat" , [first_half_of_shape , target_size_int64 . output [ 0 ] ], {'axis' : 0 })
324
+ final_target_size = ctx .make_node ("Concat" , [first_half_of_shape , target_hw ], {'axis' : 0 })
324
325
return final_target_size
325
326
326
327
@@ -1183,9 +1184,13 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
1183
1184
"method" ).s == b"nearest" else "linear"
1184
1185
extrapolation_value = float (node .get_attr ("extrapolation_value" , "0" ).f )
1185
1186
input_x = node .input [0 ]
1187
+ x_shape = ctx .make_node ("Shape" , [input_x ]).output [0 ]
1188
+ num_channels = GraphBuilder (ctx ).make_slice ({"data" : x_shape , "starts" : [3 ], "ends" : [4 ], "axes" : [0 ]})
1186
1189
boxes = node .input [1 ]
1187
1190
box_ind = node .input [2 ]
1188
1191
crop_size = node .input [3 ]
1192
+ if ctx .get_dtype (crop_size ) != TensorProto .INT64 :
1193
+ crop_size = ctx .make_node ("Cast" , [crop_size ], attr = {'to' : TensorProto .INT64 }).output [0 ]
1189
1194
trip_name = utils .make_name (node .name + "_i" )
1190
1195
cond_name = utils .make_name (node .name + "_cond" )
1191
1196
cond_out_name = utils .make_name (node .name + "cond_out" )
@@ -1233,6 +1238,10 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
1233
1238
branches = {"body" : g }
1234
1239
inner_loop = ctx .make_node ("Loop" , [trip_node .output [0 ], cond_const .output [0 ]], name = node .name ,
1235
1240
outputs = node .output , branches = branches )
1241
+ const_neg_one = ctx .make_const (utils .make_name ("const_neg_one" ), np .array ([- 1 ], np .int64 )).output [0 ]
1242
+ final_shape = ctx .make_node ("Concat" , [const_neg_one , crop_size , num_channels ], attr = {'axis' : 0 }).output [0 ]
1243
+ # This reshape fixes the case when there are no iterations and the scan output is empty.
1244
+ ctx .insert_new_node_on_output ("Reshape" , inner_loop .output [0 ], inputs = [inner_loop .output [0 ], final_shape ])
1236
1245
1237
1246
@classmethod
1238
1247
def version_11 (cls , ctx , node , ** kwargs ):
0 commit comments