@@ -315,10 +315,10 @@ def build_dynamic_target_size(ctx, transposed_intput, target_hw):
315
315
A tensor of rank 2 containing [n c nh nw]
316
316
"""
317
317
# We get the first half [n c] of the target shape
318
- shape_of_transposed_input = ctx .make_node ("Shape" , [transposed_intput . output [ 0 ] ])
318
+ shape_of_transposed_input = ctx .make_node ("Shape" , [transposed_intput ])
319
319
first_half_of_shape = GraphBuilder (ctx ).make_slice (
320
320
{"data" : shape_of_transposed_input .output [0 ], "ends" : [2 ], "starts" : [0 ]})
321
- target_size_int64 = ctx .make_node ("Cast" , [target_hw . output [ 0 ] ], attr = {'to' : TensorProto .INT64 })
321
+ target_size_int64 = ctx .make_node ("Cast" , [target_hw ], attr = {'to' : TensorProto .INT64 })
322
322
# We build a tensor containing [n c nh nw]
323
323
final_target_size = ctx .make_node ("Concat" , [first_half_of_shape , target_size_int64 .output [0 ]], {'axis' : 0 })
324
324
return final_target_size
@@ -916,10 +916,10 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
916
916
mode = "nearest" if node .get_attr ("method" ) is not None and node .get_attr (
917
917
"method" ).s == b"nearest" else "linear"
918
918
extrapolation_value = float (node .get_attr ("extrapolation_value" , "0" ).f )
919
- input_x = node .inputs [0 ]
920
- boxes = node .inputs [1 ]
921
- box_ind = node .inputs [2 ]
922
- crop_size = node .inputs [3 ]
919
+ input_x = node .input [0 ]
920
+ boxes = node .input [1 ]
921
+ box_ind = node .input [2 ]
922
+ crop_size = node .input [3 ]
923
923
trip_name = utils .make_name (node .name + "_i" )
924
924
cond_name = utils .make_name (node .name + "_cond" )
925
925
cond_out_name = utils .make_name (node .name + "cond_out" )
@@ -932,9 +932,9 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
932
932
const_one = g .make_const (utils .make_name (node .name + "_const_one" ), np .array ([1 ], dtype = np .int32 ))
933
933
const_one_long = g .make_const (utils .make_name (node .name + "_const_one_long" ), np .array ([1 ], dtype = np .int64 ))
934
934
index_end = g .make_node ("Add" , [trip_name , const_one_long .output [0 ]])
935
- box_index_from = g .make_node ("Slice" , [box_ind . output [ 0 ] , trip_name , index_end .output [0 ]], name = "Slice_a" )
935
+ box_index_from = g .make_node ("Slice" , [box_ind , trip_name , index_end .output [0 ]], name = "Slice_a" )
936
936
box_index_to = g .make_node ("Add" , [box_index_from .output [0 ], const_one .output [0 ]])
937
- target_x = g .make_node ("Slice" , [input_x . output [ 0 ] , box_index_from .output [0 ], box_index_to .output [0 ],
937
+ target_x = g .make_node ("Slice" , [input_x , box_index_from .output [0 ], box_index_to .output [0 ],
938
938
const_zero .output [0 ]], name = "Slice_b" )
939
939
transposed_x = g .make_node ("Transpose" , [target_x .output [0 ]], attr = {'perm' : constants .NHWC_TO_NCHW })
940
940
const_zero_zero = g .make_const (utils .make_name (node .name + "_const_zero_zero" ),
@@ -943,15 +943,15 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
943
943
np .array ([1 , 1 ], dtype = np .float32 ))
944
944
const_four = g .make_const (utils .make_name (node .name + "_const_four" ), np .array ([4 ], dtype = np .int64 ))
945
945
const_empty_float = g .make_const (utils .make_name ("const_empty_float" ), np .array ([], dtype = np .float32 ))
946
- box = g .make_node ("Slice" , [boxes . output [ 0 ] , trip_name , index_end .output [0 ], const_zero_long .output [0 ]],
946
+ box = g .make_node ("Slice" , [boxes , trip_name , index_end .output [0 ], const_zero_long .output [0 ]],
947
947
name = "Slice_c" )
948
948
roi_raw = g .make_node ("Reshape" , [box .output [0 ], const_four .output [0 ]])
949
949
roi_raw_first_half = GraphBuilder (g ).make_slice ({"data" : roi_raw .output [0 ], "ends" : [2 ], "starts" : [0 ]})
950
950
roi_raw_second_half = GraphBuilder (g ).make_slice ({"data" : roi_raw .output [0 ], "ends" : [4 ], "starts" : [2 ]})
951
951
roi_concat_1 = g .make_node ("Concat" , [const_zero_zero .output [0 ], roi_raw_first_half ], attr = {'axis' : 0 })
952
952
roi_concat_2 = g .make_node ("Concat" , [const_one_one .output [0 ], roi_raw_second_half ], attr = {'axis' : 0 })
953
953
final_roi = g .make_node ("Concat" , [roi_concat_1 .output [0 ], roi_concat_2 .output [0 ]], attr = {'axis' : 0 })
954
- final_crop_size = build_dynamic_target_size (g , transposed_x , crop_size )
954
+ final_crop_size = build_dynamic_target_size (g , transposed_x . output [ 0 ] , crop_size )
955
955
resized_x = g .make_node ("Resize" , [transposed_x .output [0 ], final_roi .output [0 ], const_empty_float .output [0 ],
956
956
final_crop_size .output [0 ]],
957
957
attr = {"mode" : mode , "extrapolation_value" : extrapolation_value ,
@@ -961,7 +961,7 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
961
961
g .make_node ("Identity" , [cond_name ], outputs = [cond_out_name ])
962
962
g .add_graph_output (cond_out_name , TensorProto .BOOL , [])
963
963
g .add_graph_output (squeeze_x .output [0 ], ctx .get_dtype (node .input [0 ]), [- 1 , - 1 , - 1 ])
964
- trip_node = ctx .make_node ("Size" , [box_ind . output [ 0 ] ])
964
+ trip_node = ctx .make_node ("Size" , [box_ind ])
965
965
cond_const = ctx .make_const (utils .make_name ("cond" ), np .ones ((), dtype = np .bool ))
966
966
ctx .remove_node (node .name )
967
967
branches = {"body" : g }
@@ -1070,7 +1070,7 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
1070
1070
# because onnxruntime only supports to scale the last two dims so transpose is inserted
1071
1071
input_nchw = ctx .make_node ("Transpose" , [node .input [0 ]], {"perm" : constants .NHWC_TO_NCHW })
1072
1072
if use_target_size :
1073
- final_target_size = build_dynamic_target_size (ctx , input_nchw , node .inputs [1 ])
1073
+ final_target_size = build_dynamic_target_size (ctx , input_nchw . output [ 0 ] , node .input [1 ])
1074
1074
roi = ctx .make_const (utils .make_name ("roi" ), np .array ([]).astype (np .float32 ))
1075
1075
const_empty_float = ctx .make_const (utils .make_name ("const_empty_float" ), np .array ([], dtype = np .float32 ))
1076
1076
resize_inputs = [
0 commit comments