Skip to content

Commit c091232

Browse files
jianyizhCopilotchunhuanMeng
authored
add vectorization path on maxpool forward channel last (#1883)
Part 1 of #1861. Tested on shapes from AlexNet training on BMG: scoreboard stalls decrease from 831,719 to 497,098; instruction fetch and distance stalls also improve. | shape | device | before opt | after opt | |---|---|---|---| | [4096, 64, 55, 55] | pvc | 8.02 ms | 5.44 ms | | [4096, 64, 55, 55] | bmg | 12.45 ms | 8.89 ms | | [4096, 192, 27, 27] | pvc | 5.72 ms | 3.85 ms | | [4096, 192, 27, 27] | bmg | 9.00 ms | 5.06 ms | | [4096, 256, 13, 13] | pvc | 1.68 ms | 1.12 ms | | [4096, 256, 13, 13] | bmg | 2.83 ms | 1.35 ms | --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: chunhuanMeng <[email protected]>
1 parent 83a1555 commit c091232

File tree

1 file changed

+271
-3
lines changed

1 file changed

+271
-3
lines changed

src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp

Lines changed: 271 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
#include <ATen/native/xpu/sycl/Atomics.h>
1313
#include <ATen/native/xpu/sycl/BatchKernel.h>
14+
#include <ATen/native/xpu/sycl/MemoryAccess.h>
1415
#include <ATen/native/xpu/sycl/NumericLimits.h>
16+
1517
#include <comm/Runtime.h>
1618
#include <comm/SYCLHelpers.h>
1719

@@ -151,6 +153,119 @@ struct MaxPool2dKernelFunctor {
151153
BatchKernelConfig cfg_;
152154
};
153155

156+
template <typename scalar_t, typename vec_t, int vec_size>
struct MaxPool2dChannelLastVec {
  // Channels-last (NHWC) max-pool-2d forward functor that loads and stores
  // `vec_size` adjacent channels at a time through `vec_t` (an aligned vector
  // of scalar_t). Output values are written with one vector store per output
  // pixel group; the argmax indices are written as plain int64_t, one per
  // channel. Requires numPlane_ % vec_size == 0 (enforced by the caller's
  // vec_size selection loop).
  void operator()(sycl::nd_item<1> item) const {
    // Grid-stride loop over vectorized output elements. The total element
    // count is numBatch_ * stride_ / vec_size, where stride_ is the number of
    // output scalars per batch (numPlane_ * outputSizeH_ * outputSizeW_, set
    // by the caller). The per-iteration step is the total global size
    // (local range * group range).
    for (auto outputIndex = item.get_global_linear_id();
         outputIndex < numBatch_ * stride_ / vec_size;
         outputIndex += item.get_local_range(0) * item.get_group_range(0)) {
      int batch = outputIndex / (stride_ / vec_size);
      int plane, outputH, outputW;
      int64_t load_offset, store_offset;
      // NHWC layout: channel is the fastest-moving dimension, so the
      // vectorized index decomposes (from innermost out) as
      //   plane-group, then outputW, then outputH, then batch.
      plane = outputIndex % (numPlane_ / vec_size);
      outputH =
          outputIndex / (numPlane_ / vec_size) / outputSizeW_ % outputSizeH_;
      outputW = outputIndex / (numPlane_ / vec_size) % outputSizeW_;
      // The output tensor is traversed in exactly the iteration order, so the
      // vectorized store offset is the loop index itself.
      store_offset = outputIndex;

      // Running per-channel maximum, seeded with the lowest representable
      // value so any in-bounds input replaces it.
      vec_t maxVal_vec;
#pragma unroll
      for (int i = 0; i < vec_size; i++) {
        maxVal_vec[i] = at::numeric_limits<scalar_t>::lower_bound();
      }
      // Flattened (h * inputSizeW_ + w) position of the current maximum for
      // each channel; -1 until the first window element is examined.
      int64_t maxIndex[vec_size];
      for (int i = 0; i < vec_size; i++) {
        maxIndex[i] = int64_t(-1);
      }
      // Pooling window in input coordinates: a dilated kernel footprint
      // clamped on the right/bottom to the input extent. The while-loops step
      // the start past the left/top padding in whole dilation increments so
      // only valid input positions are visited.
      int StartH = outputH * dH_ - padH_;
      int StartW = outputW * dW_ - padW_;
      int EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_);
      int EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_);
      while (StartH < 0)
        StartH += dilationH_;
      while (StartW < 0)
        StartW += dilationW_;
      for (int h = StartH; h < EndH; h += dilationH_) {
        for (int w = StartW; w < EndW; w += dilationW_) {
          // One vector load covers the same vec_size channels at input
          // position (batch, h, w); all offsets are in vec_t units, hence the
          // divisions by vec_size.
          load_offset = batch * inputSizeH_*inputSizeW_*numPlane_ / vec_size + plane +
              h * inputSizeW_ * numPlane_ / vec_size + w * numPlane_ / vec_size;
          vec_t val_vec = input_vec_[load_offset];
#pragma unroll
          for (int i = 0; i < vec_size; i++) {
            // NaN wins the comparison (the isnan term), so NaNs propagate to
            // the output value and index rather than being skipped.
            if ((static_cast<scalar_t>(val_vec[i]) > maxVal_vec[i]) ||
                at::_isnan(val_vec[i])) {
              maxIndex[i] = h * inputSizeW_ + w;
              maxVal_vec[i] = static_cast<scalar_t>(val_vec[i]);
            }
          }
        }
      }
      // Indices are stored unvectorized (one int64_t per channel); values go
      // out with a single vector write.
#pragma unroll
      for (int i = 0; i < vec_size; i++) {
        indices_[store_offset * vec_size + i] = maxIndex[i];
      }
      output_vec_[store_offset] = maxVal_vec;
    }
  }
  MaxPool2dChannelLastVec(
      vec_t* output_vec,
      int64_t* indices,
      const vec_t* input_vec,
      int numBatch,
      int numPlane,
      int inputSizeH,
      int inputSizeW,
      int outputSizeH,
      int outputSizeW,
      int kH,
      int kW,
      int dH,
      int dW,
      int padH,
      int padW,
      int dilationH,
      int dilationW,
      int stride)
      : output_vec_(output_vec),
        indices_(indices),
        input_vec_(input_vec),
        numBatch_(numBatch),
        numPlane_(numPlane),
        inputSizeH_(inputSizeH),
        inputSizeW_(inputSizeW),
        outputSizeH_(outputSizeH),
        outputSizeW_(outputSizeW),
        kH_(kH),
        kW_(kW),
        dH_(dH),
        dW_(dW),
        padH_(padH),
        padW_(padW),
        dilationH_(dilationH),
        dilationW_(dilationW),
        stride_(stride) {}

 private:
  vec_t* output_vec_; // vectorized NHWC output values
  int64_t* indices_; // per-channel argmax, flattened as h * inputSizeW_ + w
  const vec_t* input_vec_; // vectorized NHWC input
  int numBatch_;
  int numPlane_; // channel count; must be divisible by vec_size
  int inputSizeH_;
  int inputSizeW_;
  int outputSizeH_;
  int outputSizeW_;
  int kH_; // pooling kernel extent
  int kW_;
  int dH_; // pooling stride (outputH * dH_ - padH_ maps output -> input)
  int dW_;
  int padH_;
  int padW_;
  int dilationH_;
  int dilationW_;
  int stride_; // output scalars per batch: numPlane_ * H_out * W_out
};
268+
154269
template <typename scalar_t, bool is_channels_last>
155270
struct MaxPool2dBackwardKernelFunctor {
156271
void operator()(sycl::nd_item<2> item) const {
@@ -349,6 +464,56 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
349464
BatchKernelConfig cfg_;
350465
};
351466

467+
// Reinterprets the scalar `output`/`input` pointers as aligned vectors of
// `vec_size` elements, instantiates MaxPool2dChannelLastVec with the given
// pooling parameters, and submits it to `queue` as a 1-D kernel of
// num_wg work-groups of wg_size items each. This is a macro (rather than a
// function) because vec_size must be a compile-time constant for both the
// aligned_vector type and the functor template instantiation. The caller
// guarantees the pointers satisfy the vec_size alignment it probed with
// memory::can_vectorize_up_to.
#define LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( \
    scalar_t, \
    vec_size, \
    num_wg, \
    wg_size, \
    queue, \
    output, \
    indices, \
    input, \
    numBatch, \
    numPlane, \
    inputSizeH, \
    inputSizeW, \
    outputSizeH, \
    outputSizeW, \
    kH, \
    kW, \
    dH, \
    dW, \
    padH, \
    padW, \
    dilationH, \
    dilationW, \
    stride) \
  { \
    using vec_t = memory::aligned_vector<scalar_t, vec_size>; \
    vec_t* output_vec = reinterpret_cast<vec_t*>(output); \
    const vec_t* input_vec = reinterpret_cast<const vec_t*>(input); \
    auto kfn = MaxPool2dChannelLastVec<scalar_t, vec_t, vec_size>( \
        output_vec, \
        indices, \
        input_vec, \
        numBatch, \
        numPlane, \
        inputSizeH, \
        inputSizeW, \
        outputSizeH, \
        outputSizeW, \
        kH, \
        kW, \
        dH, \
        dW, \
        padH, \
        padW, \
        dilationH, \
        dilationW, \
        stride); \
    sycl_kernel_submit(num_wg * wg_size, wg_size, queue, kfn); \
  }
516+
352517
template <typename scalar_t, bool is_channels_last>
353518
void launch_max_pool2d_kernel(
354519
scalar_t* output,
@@ -368,11 +533,114 @@ void launch_max_pool2d_kernel(
368533
int padW,
369534
int dilationH,
370535
int dilationW) {
371-
using KernelClass = MaxPool2dKernelFunctor<scalar_t, is_channels_last>;
372-
373536
auto& queue = at::xpu::getCurrentSYCLQueue();
374537
int outputSize = numBatch * numPlane * outputSizeH * outputSizeW;
375538
int stride = numPlane * outputSizeH * outputSizeW;
539+
int vec_size = 1;
540+
int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
541+
int num_sub_wg;
542+
auto wg_size = syclDeviceMaxWorkGroupSize();
543+
int64_t num_wg;
544+
if constexpr (is_channels_last) {
545+
for (vec_size =
546+
std::min(8, memory::can_vectorize_up_to<scalar_t>((char*)input));
547+
vec_size >= 1;
548+
vec_size /= 2) {
549+
if (numPlane % vec_size != 0) {
550+
continue;
551+
}
552+
num_sub_wg = outputSize / vec_size / syclMaxSubGroupSize();
553+
if (2 * num_sub_wg > thread_slots) {
554+
int total_thread = outputSize / vec_size;
555+
num_wg = (total_thread + wg_size - 1) / wg_size;
556+
break;
557+
}
558+
}
559+
switch (vec_size) {
560+
case 8:
561+
LAUNCH_MAXPOOL_CHANNEL_LAST_VEC(
562+
scalar_t,
563+
8,
564+
num_wg,
565+
wg_size,
566+
queue,
567+
output,
568+
indices,
569+
input,
570+
numBatch,
571+
numPlane,
572+
inputSizeH,
573+
inputSizeW,
574+
outputSizeH,
575+
outputSizeW,
576+
kH,
577+
kW,
578+
dH,
579+
dW,
580+
padH,
581+
padW,
582+
dilationH,
583+
dilationW,
584+
stride);
585+
return;
586+
case 4:
587+
LAUNCH_MAXPOOL_CHANNEL_LAST_VEC(
588+
scalar_t,
589+
4,
590+
num_wg,
591+
wg_size,
592+
queue,
593+
output,
594+
indices,
595+
input,
596+
numBatch,
597+
numPlane,
598+
inputSizeH,
599+
inputSizeW,
600+
outputSizeH,
601+
outputSizeW,
602+
kH,
603+
kW,
604+
dH,
605+
dW,
606+
padH,
607+
padW,
608+
dilationH,
609+
dilationW,
610+
stride);
611+
return;
612+
case 2:
613+
LAUNCH_MAXPOOL_CHANNEL_LAST_VEC(
614+
scalar_t,
615+
2,
616+
num_wg,
617+
wg_size,
618+
queue,
619+
output,
620+
indices,
621+
input,
622+
numBatch,
623+
numPlane,
624+
inputSizeH,
625+
inputSizeW,
626+
outputSizeH,
627+
outputSizeW,
628+
kH,
629+
kW,
630+
dH,
631+
dW,
632+
padH,
633+
padW,
634+
dilationH,
635+
dilationW,
636+
stride);
637+
return;
638+
default:
639+
break;
640+
};
641+
}
642+
using KernelClass = MaxPool2dKernelFunctor<scalar_t, is_channels_last>;
643+
376644
BatchKernelConfig cfg = BatchKernelConfig::make_config<KernelClass>(
377645
1, outputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive);
378646
auto kfn = KernelClass(
@@ -704,6 +972,6 @@ void max_pool2d_with_indices_backward_kernel(
704972
}
705973

706974
} // namespace at::native::xpu
707-
975+
#undef LAUNCH_MAXPOOL_CHANNEL_LAST_VEC
708976
#pragma GCC diagnostic pop
709977
#pragma clang diagnostic pop

0 commit comments

Comments
 (0)