@@ -58,23 +58,38 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
5858 else :
5959 from torch .onnx .symbolic_opset9 import _cast_Long
6060 from torch .onnx .symbolic_opset11 import add , select
61- batch_indices = _cast_Long (
62- g ,
63- g .op (
64- 'Squeeze' ,
65- select (
66- g , rois , 1 ,
67- g .op (
68- 'Constant' ,
69- value_t = torch .tensor ([0 ], dtype = torch .long ))),
70- axes_i = [1 ]), False )
61+ ir_cfg = get_ir_config (ctx .cfg )
62+ opset_version = ir_cfg .get ('opset_version' , 11 )
63+ if opset_version < 13 :
64+ batch_indices = _cast_Long (
65+ g ,
66+ g .op (
67+ 'Squeeze' ,
68+ select (
69+ g , rois , 1 ,
70+ g .op (
71+ 'Constant' ,
72+ value_t = torch .tensor ([0 ], dtype = torch .long ))),
73+ axes_i = [1 ]), False )
74+ else :
75+ axes = g .op (
76+ 'Constant' , value_t = torch .tensor ([1 ], dtype = torch .long ))
77+ batch_indices = _cast_Long (
78+ g ,
79+ g .op (
80+ 'Squeeze' ,
81+ select (
82+ g , rois , 1 ,
83+ g .op (
84+ 'Constant' ,
85+ value_t = torch .tensor ([0 ], dtype = torch .long ))),
86+ axes ), False )
7187 rois = select (
7288 g , rois , 1 ,
7389 g .op (
7490 'Constant' ,
7591 value_t = torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .long )))
76- ir_cfg = get_ir_config (ctx .cfg )
77- opset_version = ir_cfg .get ('opset_version' , 11 )
92+
7893 if opset_version < 16 :
7994 # preprocess rois to make compatible with opset 16-
8095 # as for opset 16+, `aligned` get implemented inside onnxruntime.
@@ -96,6 +111,10 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
96111 sampling_ratio_i = sampling_ratio ,
97112 mode_s = pool_mode )
98113 else :
114+ if aligned :
115+ coordinate_transformation_mode = 'half_pixel'
116+ else :
117+ coordinate_transformation_mode = 'output_half_pixel'
99118 return g .op (
100119 'RoiAlign' ,
101120 input ,
@@ -106,4 +125,5 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
106125 spatial_scale_f = spatial_scale ,
107126 sampling_ratio_i = sampling_ratio ,
108127 mode_s = pool_mode ,
109- aligned_i = aligned )
128+ coordinate_transformation_mode_s = coordinate_transformation_mode
129+ )
0 commit comments