@@ -27,28 +27,25 @@ T bilinear_interpolate(
     return 0;
   }

-  if (y <= 0)
-    y = 0;
-  if (x <= 0)
-    x = 0;
+  y = std::max(T(0), y);
+  x = std::max(T(0), x);

   int y_low = (int)y;
   int x_low = (int)x;
   int y_high;
   int x_high;

-  if (y_low >= height - 1) {
-    y_high = y_low = height - 1;
+  y_low = std::min(height - 1, y_low);
+  x_low = std::min(width - 1, x_low);
+  y_high = std::min(y_low + 1, height - 1);
+  x_high = std::min(x_low + 1, width - 1);
+
+  if (y_low == height - 1) {
     y = (T)y_low;
-  } else {
-    y_high = y_low + 1;
   }

-  if (x_low >= width - 1) {
-    x_high = x_low = width - 1;
+  if (x_low == width - 1) {
     x = (T)x_low;
-  } else {
-    x_high = x_low + 1;
   }

   T ly = y - y_low;
@@ -67,24 +64,39 @@ T bilinear_interpolate(
   return val;
 }
 template <typename T>
-struct RoiAlignForwardKernel {
+struct RoiAlignForwardKernel : public __SYCL_KER_CONFIG_CONVENTION__ {
   void operator()(sycl::nd_item<1> item) const {
-    XPU_KERNEL_LOOP(item, index, nthreads_) {
-      // (n, c, ph, pw) is an element in the pooled output
-      int pw = index % pooled_width_;
-      int ph = (index / pooled_width_) % pooled_height_;
-      int c = (index / pooled_width_ / pooled_height_) % channels_;
-      int n = index / pooled_width_ / pooled_height_ / channels_;
-
-      const T* offset_rois = rois_ + n * 5;
-      int roi_batch_ind = offset_rois[0];
+    // each roi will have 5 values, batch_idx,x1,y1,x2,y2
+    constexpr int roi_size = 5;
+    auto wg = item.get_group(0);
+    int n = wg / wgs_per_roi_;
+    int output_index_on_batch_n =
+        (wg - n * wgs_per_roi_) * item.get_local_range(0) +
+        item.get_local_id(0);
+    const T* current_roi = rois_ + n * roi_size;
+    if (item.get_local_id(0) == 0) {
+      cached_roi_[0] = current_roi[0];

       // Do not using rounding; this implementation detail is critical
       T offset = aligned_ ? (T)0.5 : (T)0.0;
-      T roi_start_w = offset_rois[1] * spatial_scale_ - offset;
-      T roi_start_h = offset_rois[2] * spatial_scale_ - offset;
-      T roi_end_w = offset_rois[3] * spatial_scale_ - offset;
-      T roi_end_h = offset_rois[4] * spatial_scale_ - offset;
+      cached_roi_[1] = current_roi[1] * spatial_scale_ - offset;
+      cached_roi_[2] = current_roi[2] * spatial_scale_ - offset;
+      cached_roi_[3] = current_roi[3] * spatial_scale_ - offset;
+      cached_roi_[4] = current_roi[4] * spatial_scale_ - offset;
+    }
+    item.barrier(sycl_local_fence);
+
+    if (output_index_on_batch_n < items_per_roi_) {
+      int pw = output_index_on_batch_n % pooled_width_;
+      int ph = (output_index_on_batch_n / pooled_width_) % pooled_height_;
+      int c = (output_index_on_batch_n / pooled_width_ / pooled_height_) %
+          channels_;
+
+      int roi_batch_ind = cached_roi_[0];
+      T roi_start_w = cached_roi_[1];
+      T roi_start_h = cached_roi_[2];
+      T roi_end_w = cached_roi_[3];
+      T roi_end_h = cached_roi_[4];

       T roi_width = roi_end_w - roi_start_w;
       T roi_height = roi_end_h - roi_start_h;
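The hunk above caches each ROI's five descriptor values in work-group local memory: work item 0 of a group loads and scales them once, a barrier publishes them, and every work item then reads the local copy instead of global memory. The following is a minimal standalone sketch of that pattern in plain SYCL 2020, not code from this PR: the queue setup, array names, and the token value written at the end are assumptions, and a raw `sycl::local_accessor` stands in for the `sycl_local_acc_t` / `sycl_ker_config_convention()` helpers the kernel uses.

```cpp
// Sketch (assumed harness): cache a per-group ROI descriptor in local memory.
#include <sycl/sycl.hpp>
#include <iostream>
#include <vector>

int main() {
  constexpr int roi_size = 5;     // batch_idx, x1, y1, x2, y2
  constexpr int num_rois = 4;
  constexpr int wgs_per_roi = 2;  // work groups cooperating on one ROI
  constexpr int local_range = 64;

  sycl::queue q;
  std::vector<float> rois(num_rois * roi_size, 1.0f);
  float* rois_dev = sycl::malloc_device<float>(rois.size(), q);
  float* out_dev = sycl::malloc_device<float>(num_rois, q);
  q.copy(rois.data(), rois_dev, rois.size()).wait();

  q.submit([&](sycl::handler& cgh) {
     // Local memory shared by the work items of one work group.
     sycl::local_accessor<float, 1> cached_roi(sycl::range<1>(roi_size), cgh);
     cgh.parallel_for(
         sycl::nd_range<1>(
             sycl::range<1>(num_rois * wgs_per_roi * local_range),
             sycl::range<1>(local_range)),
         [=](sycl::nd_item<1> item) {
           int wg = item.get_group(0);
           int n = wg / wgs_per_roi;  // which ROI this group works on
           if (item.get_local_id(0) == 0) {
             for (int i = 0; i < roi_size; ++i)
               cached_roi[i] = rois_dev[n * roi_size + i];
           }
           // Make the cached values visible to the whole work group.
           sycl::group_barrier(item.get_group());
           // Every work item can now read the ROI from local memory; here the
           // first group per ROI just records a token value.
           if (item.get_local_id(0) == 0 && wg % wgs_per_roi == 0)
             out_dev[n] = cached_roi[1] + cached_roi[3];
         });
   }).wait();

  std::vector<float> out(num_rois);
  q.copy(out_dev, out.data(), out.size()).wait();
  std::cout << out[0] << "\n";

  sycl::free(rois_dev, q);
  sycl::free(out_dev, q);
}
```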
@@ -125,20 +137,26 @@ struct RoiAlignForwardKernel {
              static_cast<T>(ix + .5f) * bin_size_w /
                  static_cast<T>(roi_bin_grid_w);

-          T val =
-              bilinear_interpolate(offset_input, height_, width_, y, x, index);
+          T val = bilinear_interpolate(
+              offset_input,
+              height_,
+              width_,
+              y,
+              x,
+              output_index_on_batch_n + n * items_per_roi_);
           output_val += val;
         }
       }
       output_val /= count;

-      output_[index] = output_val;
+      output_[output_index_on_batch_n + n * items_per_roi_] = output_val;
     }
   }
   RoiAlignForwardKernel(
-      int nthreads,
       const T* input,
       const T spatial_scale,
+      int items_per_rois,
+      int wgs_per_roi,
       int channels,
       int height,
       int width,
@@ -148,9 +166,10 @@ struct RoiAlignForwardKernel {
       bool aligned,
       const T* rois,
       T* output)
-      : nthreads_(nthreads),
-        input_(input),
+      : input_(input),
         spatial_scale_(spatial_scale),
+        items_per_roi_(items_per_rois),
+        wgs_per_roi_(wgs_per_roi),
         channels_(channels),
         height_(height),
         width_(width),
@@ -160,20 +179,26 @@ struct RoiAlignForwardKernel {
         aligned_(aligned),
         rois_(rois),
         output_(output) {}
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    // each roi will have 5 values, batch_idx,x1,y1,x2,y2
+    cached_roi_ = sycl_local_acc_t<T>(5, cgh);
+  }

  private:
-  int nthreads_;
   const T* input_;
   const T spatial_scale_;
-  int channels_;
-  int height_;
-  int width_;
-  int pooled_height_;
-  int pooled_width_;
-  int sampling_ratio_;
-  bool aligned_;
+  const int items_per_roi_;
+  const int wgs_per_roi_;
+  const int channels_;
+  const int height_;
+  const int width_;
+  const int pooled_height_;
+  const int pooled_width_;
+  const int sampling_ratio_;
+  const bool aligned_;
   const T* rois_;
   T* output_;
+  sycl_local_acc_t<T> cached_roi_;
 };

 template <typename T>
@@ -415,11 +440,7 @@ Tensor roi_align_kernel(
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width}, input.options());

-
   auto output_size = num_rois * pooled_height * pooled_width * channels;
-  int64_t global_range =
-      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512));
-  int64_t local_range = 512;

   if (output.numel() == 0) {
     return output;
@@ -433,10 +454,22 @@ Tensor roi_align_kernel(
       input.scalar_type(),
       "roi_align_forward_kernel_xpu",
       [&] {
+        int64_t local_range =
+            syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>();
+        int items_per_roi = pooled_height * pooled_width * channels;
+        if (items_per_roi < local_range) {
+          constexpr int simd_len = 32;
+          local_range = std::min(
+              local_range,
+              int64_t(items_per_roi + simd_len - 1) / simd_len * simd_len);
+        }
+        int wgs_per_roi = (items_per_roi + local_range - 1) / local_range;
+        int64_t global_range = wgs_per_roi * num_rois;
         auto kfn = RoiAlignForwardKernel<scalar_t>(
-            output_size,
             input_.data_ptr<scalar_t>(),
             spatial_scale,
+            items_per_roi,
+            wgs_per_roi,
             channels,
             height,
             width,
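The dispatch hunk above replaces the fixed 512-wide launch with a geometry derived from each ROI's output count. Below is a host-only sketch of that arithmetic, under assumed values: `max_wg` stands in for `syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>()` and the tensor dimensions are made up, while the rounding logic mirrors the hunk.

```cpp
// Sketch of the launch geometry: shrink the work group for small ROIs, then
// derive how many groups cover one ROI and the total group count.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t max_wg = 1024;  // stand-in for syclMaxWorkGroupSize<...>()
  const int pooled_height = 7, pooled_width = 7, channels = 256;
  const int num_rois = 300;

  int64_t local_range = max_wg;
  int items_per_roi = pooled_height * pooled_width * channels;
  if (items_per_roi < local_range) {
    // Small ROIs: round items_per_roi up to a multiple of the assumed SIMD
    // width of 32, capped at the device maximum.
    constexpr int simd_len = 32;
    local_range = std::min(
        local_range,
        int64_t(items_per_roi + simd_len - 1) / simd_len * simd_len);
  }
  // Work groups needed to cover one ROI's outputs, and the total group count
  // the kernel is launched with (global_range in the diff).
  int wgs_per_roi =
      static_cast<int>((items_per_roi + local_range - 1) / local_range);
  int64_t global_range = int64_t(wgs_per_roi) * num_rois;

  std::cout << "local_range=" << local_range << " wgs_per_roi=" << wgs_per_roi
            << " global_range=" << global_range << "\n";
}
```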