Skip to content

Commit 337deed

Browse files
jianyizh, toyxu
and authored
Optimize roi_align on BMG (#1698)
For input [1, 2048, 50, 75], rois [1000,5], roi align takes 4.7 ms on PVC but 75 ms on BMG. Each roi will have 2048xoutput_hxoutput_w work items reading the same value from LLC, and it's very slow on BMG. After put them into shared local memory, PVC takes 4.0ms, BMG reaches 7.5ms. I also removed some if else branching by min/max. I also fix a code style issue. --------- Co-authored-by: Yutao Xu <[email protected]>
1 parent 88e09c5 commit 337deed

File tree

2 files changed

+89
-52
lines changed

2 files changed

+89
-52
lines changed

src/ATen/native/xpu/sycl/RoiAlignKernels.cpp

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,25 @@ T bilinear_interpolate(
2727
return 0;
2828
}
2929

30-
if (y <= 0)
31-
y = 0;
32-
if (x <= 0)
33-
x = 0;
30+
y = std::max(T(0), y);
31+
x = std::max(T(0), x);
3432

3533
int y_low = (int)y;
3634
int x_low = (int)x;
3735
int y_high;
3836
int x_high;
3937

40-
if (y_low >= height - 1) {
41-
y_high = y_low = height - 1;
38+
y_low = std::min(height - 1, y_low);
39+
x_low = std::min(width - 1, x_low);
40+
y_high = std::min(y_low + 1, height - 1);
41+
x_high = std::min(x_low + 1, width - 1);
42+
43+
if (y_low == height - 1) {
4244
y = (T)y_low;
43-
} else {
44-
y_high = y_low + 1;
4545
}
4646

47-
if (x_low >= width - 1) {
48-
x_high = x_low = width - 1;
47+
if (x_low == width - 1) {
4948
x = (T)x_low;
50-
} else {
51-
x_high = x_low + 1;
5249
}
5350

5451
T ly = y - y_low;
@@ -67,24 +64,39 @@ T bilinear_interpolate(
6764
return val;
6865
}
6966
template <typename T>
70-
struct RoiAlignForwardKernel {
67+
struct RoiAlignForwardKernel : public __SYCL_KER_CONFIG_CONVENTION__ {
7168
void operator()(sycl::nd_item<1> item) const {
72-
XPU_KERNEL_LOOP(item, index, nthreads_) {
73-
// (n, c, ph, pw) is an element in the pooled output
74-
int pw = index % pooled_width_;
75-
int ph = (index / pooled_width_) % pooled_height_;
76-
int c = (index / pooled_width_ / pooled_height_) % channels_;
77-
int n = index / pooled_width_ / pooled_height_ / channels_;
78-
79-
const T* offset_rois = rois_ + n * 5;
80-
int roi_batch_ind = offset_rois[0];
69+
// each roi will have 5 values, batch_idx,x1,y1,x2,y2
70+
constexpr int roi_size = 5;
71+
auto wg = item.get_group(0);
72+
int n = wg / wgs_per_roi_;
73+
int output_index_on_batch_n =
74+
(wg - n * wgs_per_roi_) * item.get_local_range(0) +
75+
item.get_local_id(0);
76+
const T* current_roi = rois_ + n * roi_size;
77+
if (item.get_local_id(0) == 0) {
78+
cached_roi_[0] = current_roi[0];
8179

8280
// Do not using rounding; this implementation detail is critical
8381
T offset = aligned_ ? (T)0.5 : (T)0.0;
84-
T roi_start_w = offset_rois[1] * spatial_scale_ - offset;
85-
T roi_start_h = offset_rois[2] * spatial_scale_ - offset;
86-
T roi_end_w = offset_rois[3] * spatial_scale_ - offset;
87-
T roi_end_h = offset_rois[4] * spatial_scale_ - offset;
82+
cached_roi_[1] = current_roi[1] * spatial_scale_ - offset;
83+
cached_roi_[2] = current_roi[2] * spatial_scale_ - offset;
84+
cached_roi_[3] = current_roi[3] * spatial_scale_ - offset;
85+
cached_roi_[4] = current_roi[4] * spatial_scale_ - offset;
86+
}
87+
item.barrier(sycl_local_fence);
88+
89+
if (output_index_on_batch_n < items_per_roi_) {
90+
int pw = output_index_on_batch_n % pooled_width_;
91+
int ph = (output_index_on_batch_n / pooled_width_) % pooled_height_;
92+
int c = (output_index_on_batch_n / pooled_width_ / pooled_height_) %
93+
channels_;
94+
95+
int roi_batch_ind = cached_roi_[0];
96+
T roi_start_w = cached_roi_[1];
97+
T roi_start_h = cached_roi_[2];
98+
T roi_end_w = cached_roi_[3];
99+
T roi_end_h = cached_roi_[4];
88100

89101
T roi_width = roi_end_w - roi_start_w;
90102
T roi_height = roi_end_h - roi_start_h;
@@ -125,20 +137,26 @@ struct RoiAlignForwardKernel {
125137
static_cast<T>(ix + .5f) * bin_size_w /
126138
static_cast<T>(roi_bin_grid_w);
127139

128-
T val =
129-
bilinear_interpolate(offset_input, height_, width_, y, x, index);
140+
T val = bilinear_interpolate(
141+
offset_input,
142+
height_,
143+
width_,
144+
y,
145+
x,
146+
output_index_on_batch_n + n * items_per_roi_);
130147
output_val += val;
131148
}
132149
}
133150
output_val /= count;
134151

135-
output_[index] = output_val;
152+
output_[output_index_on_batch_n + n * items_per_roi_] = output_val;
136153
}
137154
}
138155
RoiAlignForwardKernel(
139-
int nthreads,
140156
const T* input,
141157
const T spatial_scale,
158+
int items_per_rois,
159+
int wgs_per_roi,
142160
int channels,
143161
int height,
144162
int width,
@@ -148,9 +166,10 @@ struct RoiAlignForwardKernel {
148166
bool aligned,
149167
const T* rois,
150168
T* output)
151-
: nthreads_(nthreads),
152-
input_(input),
169+
: input_(input),
153170
spatial_scale_(spatial_scale),
171+
items_per_roi_(items_per_rois),
172+
wgs_per_roi_(wgs_per_roi),
154173
channels_(channels),
155174
height_(height),
156175
width_(width),
@@ -160,20 +179,26 @@ struct RoiAlignForwardKernel {
160179
aligned_(aligned),
161180
rois_(rois),
162181
output_(output) {}
182+
void sycl_ker_config_convention(sycl::handler& cgh) {
183+
// each roi will have 5 values, batch_idx,x1,y1,x2,y2
184+
cached_roi_ = sycl_local_acc_t<T>(5, cgh);
185+
}
163186

164187
private:
165-
int nthreads_;
166188
const T* input_;
167189
const T spatial_scale_;
168-
int channels_;
169-
int height_;
170-
int width_;
171-
int pooled_height_;
172-
int pooled_width_;
173-
int sampling_ratio_;
174-
bool aligned_;
190+
const int items_per_roi_;
191+
const int wgs_per_roi_;
192+
const int channels_;
193+
const int height_;
194+
const int width_;
195+
const int pooled_height_;
196+
const int pooled_width_;
197+
const int sampling_ratio_;
198+
const bool aligned_;
175199
const T* rois_;
176200
T* output_;
201+
sycl_local_acc_t<T> cached_roi_;
177202
};
178203

179204
template <typename T>
@@ -415,11 +440,7 @@ Tensor roi_align_kernel(
415440

416441
at::Tensor output = at::zeros(
417442
{num_rois, channels, pooled_height, pooled_width}, input.options());
418-
419443
auto output_size = num_rois * pooled_height * pooled_width * channels;
420-
int64_t global_range =
421-
ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512));
422-
int64_t local_range = 512;
423444

424445
if (output.numel() == 0) {
425446
return output;
@@ -433,10 +454,22 @@ Tensor roi_align_kernel(
433454
input.scalar_type(),
434455
"roi_align_forward_kernel_xpu",
435456
[&] {
457+
int64_t local_range =
458+
syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>();
459+
int items_per_roi = pooled_height * pooled_width * channels;
460+
if (items_per_roi < local_range) {
461+
constexpr int simd_len = 32;
462+
local_range = std::min(
463+
local_range,
464+
int64_t(items_per_roi + simd_len - 1) / simd_len * simd_len);
465+
}
466+
int wgs_per_roi = (items_per_roi + local_range - 1) / local_range;
467+
int64_t global_range = wgs_per_roi * num_rois;
436468
auto kfn = RoiAlignForwardKernel<scalar_t>(
437-
output_size,
438469
input_.data_ptr<scalar_t>(),
439470
spatial_scale,
471+
items_per_roi,
472+
wgs_per_roi,
440473
channels,
441474
height,
442475
width,

src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -425,9 +425,9 @@ struct UpsampleBilinear2dBackwardNotAlignKernelFunctor {
425425
// scale is 1 if on boundary
426426
distance_w =
427427
distance_w + is_boundary_w * (output_width_ * 2 - distance_w);
428-
bool is_boundary_h =
429-
!((point_h >= output_height_) &&
430-
(point_h <= output_height_ * input_height_ * 2 - output_height_));
428+
bool is_boundary_h = !(
429+
(point_h >= output_height_) &&
430+
(point_h <= output_height_ * input_height_ * 2 - output_height_));
431431
distance_h =
432432
distance_h + is_boundary_h * (output_height_ * 2 - distance_h);
433433
accscalar_t scale =
@@ -606,8 +606,10 @@ void launch_upsample_bilinear2d_backward_kernel(
606606
// TODO: when input 3x3, scale is 1.5, output is 4x4,
607607
// pytorch prefer use 1/1.5, but my implementation treat it as 3/4...
608608
// I also have to skip double because of rounding issues, it will not pass ut
609-
can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
610-
input_height == (rheight * output_height))) &&
609+
can_optimize = can_optimize &&
610+
(align_corners ||
611+
(input_width == (rwidth * output_width) &&
612+
input_height == (rheight * output_height))) &&
611613
!std::is_same<scalar_t, double>::value;
612614
if (can_optimize) {
613615
if (align_corners) {
@@ -790,8 +792,10 @@ void launch_upsample_bilinear2d_backward_nhwc_kernel(
790792
// TODO: when input 3x3, scale is 1.5, output is 4x4,
791793
// pytorch prefer use 1/1.5, but my implementation treat it as 3/4...
792794
// I also have to skip double because of rounding issues, it will not pass ut
793-
can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
794-
input_height == (rheight * output_height))) &&
795+
can_optimize = can_optimize &&
796+
(align_corners ||
797+
(input_width == (rwidth * output_width) &&
798+
input_height == (rheight * output_height))) &&
795799
!std::is_same<scalar_t, double>::value;
796800
if (can_optimize) {
797801
if (align_corners) {

0 commit comments

Comments
 (0)