276 changes: 273 additions & 3 deletions src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
@@ -11,6 +11,7 @@

#include <ATen/native/xpu/sycl/Atomics.h>
#include <ATen/native/xpu/sycl/BatchKernel.h>
#include <ATen/native/xpu/sycl/MemoryAccess.h>
#include <ATen/native/xpu/sycl/NumericLimits.h>
#include <comm/Runtime.h>
#include <comm/SYCLHelpers.h>
@@ -251,13 +252,14 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
int pwstart =
p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_);
int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_);
scalar_t grad = 0;
if constexpr (is_channels_last) {
int offset = batch * out_n_stride_ + plane;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (indices_[offset + (ph * gradOutputSizeW_ + pw) * numPlane_] ==
input_hw_index) {
gradInput_[inputIndex] += static_cast<scalar_t>(
grad += static_cast<scalar_t>(
gradOutput_
[offset + (ph * gradOutputSizeW_ + pw) * numPlane_]);
}
@@ -269,12 +271,13 @@
for (int pw = pwstart; pw < pwend; ++pw) {
if (indices_[offset + ph * gradOutputSizeW_ + pw] ==
input_hw_index) {
gradInput_[inputIndex] += static_cast<scalar_t>(
grad += static_cast<scalar_t>(
gradOutput_[offset + ph * gradOutputSizeW_ + pw]);
}
}
}
}
gradInput_[inputIndex] = grad;
}
} while (cfg_.next(item, desc));
}
@@ -349,6 +352,116 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
BatchKernelConfig cfg_;
};

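// Vectorized variant of the deterministic channels-last backward kernel.
// Each work-item owns one aligned_vector of vec_size contiguous channels at a
// single (batch, h, w) position, accumulates the matching gradOutput
// contributions per lane, and writes gradInput with a single vectorized
// store, so no atomics are needed and the result stays deterministic.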
template <typename scalar_t, typename vec_t, int vec_size>
struct MaxPool2dBackwardChannelLastVec {
void operator()(sycl::nd_item<1> item) const {
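    // Grid-stride loop over gradInput, processed in units of vec_size
    // channels.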
for (auto inputIndex = item.get_global_linear_id();
inputIndex < gradInputSize_ / vec_size;
inputIndex += item.get_local_range(0) * item.get_group_range(0)) {
int batch = inputIndex / (in_n_stride_ / vec_size);
int plane;
int64_t input_hw_index;

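      // Decompose the vectorized NHWC index: 'plane' is the channel-vector
      // slot, 'input_hw_index' the flattened spatial (h * W + w) position.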
plane = inputIndex % (numPlane_ / vec_size);

input_hw_index = ((inputIndex % (in_n_stride_ / vec_size)) - plane) /
(numPlane_ / vec_size);

int inputW = input_hw_index % gradInputSizeW_;
int inputH = input_hw_index / gradInputSizeW_;
int phstart = p_start(inputH, pad_h_, kernel_h_, dilation_h_, stride_h_);
int phend = p_end(inputH, pad_h_, gradOutputSizeH_, stride_h_);
int pwstart = p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_);
int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_);
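      // One partial sum per vector lane.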
vec_t grad_vec;
#pragma unroll
for (int i = 0; i < vec_size; i++) {
grad_vec[i] = 0;
}

int offset = batch * out_n_stride_ / vec_size + plane;
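      // Walk every pooling window overlapping this position; a lane
      // contributes only when its stored argmax equals input_hw_index.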
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int load_offset = offset +
ph * gradOutputSizeW_ * numPlane_ / vec_size +
pw * numPlane_ / vec_size;
vec_t gout_val_vec = gradOutput_[load_offset];
#pragma unroll
for (int i = 0; i < vec_size; i++) {
if (indices_[load_offset * vec_size + i] == input_hw_index) {
grad_vec[i] = static_cast<scalar_t>(grad_vec[i]) +
static_cast<scalar_t>(gout_val_vec[i]);
}
}
}
}

gradInput_[inputIndex] = grad_vec;
}
}
MaxPool2dBackwardChannelLastVec(
vec_t* gradInput,
const vec_t* gradOutput,
const int64_t* indices,
int numPlane,
int gradInputSizeH,
int gradInputSizeW,
int gradOutputSizeH,
int gradOutputSizeW,
int64_t gradInputSize,
int out_n_stride,
int in_n_stride,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w)
: gradInput_(gradInput),
gradOutput_(gradOutput),
indices_(indices),
numPlane_(numPlane),
gradInputSizeH_(gradInputSizeH),
gradInputSizeW_(gradInputSizeW),
gradOutputSizeH_(gradOutputSizeH),
gradOutputSizeW_(gradOutputSizeW),
gradInputSize_(gradInputSize),
out_n_stride_(out_n_stride),
in_n_stride_(in_n_stride),
kernel_h_(kernel_h),
kernel_w_(kernel_w),
stride_h_(stride_h),
stride_w_(stride_w),
pad_h_(pad_h),
pad_w_(pad_w),
dilation_h_(dilation_h),
dilation_w_(dilation_w) {}

private:
vec_t* gradInput_;
const vec_t* gradOutput_;
const int64_t* indices_;
int numPlane_;
int gradInputSizeH_;
int gradInputSizeW_;
int gradOutputSizeH_;
int gradOutputSizeW_;
int64_t gradInputSize_;
int out_n_stride_;
int in_n_stride_;
int kernel_h_;
int kernel_w_;
int stride_h_;
int stride_w_;
int pad_h_;
int pad_w_;
int dilation_h_;
int dilation_w_;
};
template <typename scalar_t, bool is_channels_last>
void launch_max_pool2d_kernel(
scalar_t* output,
@@ -397,6 +510,58 @@ void launch_max_pool2d_kernel(
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
}

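// Reinterprets gradInput/gradOutput as aligned_vector<scalar_t, vec_size> and
// submits MaxPool2dBackwardChannelLastVec with num_wg work-groups of wg_size.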
#define LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( \
scalar_t, \
vec_size, \
num_wg, \
wg_size, \
queue, \
gradInput, \
gradOutput, \
indices, \
numPlane, \
gradInputSizeH, \
gradInputSizeW, \
gradOutputSizeH, \
gradOutputSizeW, \
gradInputSize, \
out_n_stride, \
in_n_stride, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w) \
{ \
using vec_t = memory::aligned_vector<scalar_t, vec_size>; \
const vec_t* grad_output_vec = reinterpret_cast<const vec_t*>(gradOutput); \
vec_t* grad_input_vec = reinterpret_cast<vec_t*>(gradInput); \
auto kfn = MaxPool2dBackwardChannelLastVec<scalar_t, vec_t, vec_size>( \
grad_input_vec, \
grad_output_vec, \
indices, \
numPlane, \
gradInputSizeH, \
gradInputSizeW, \
gradOutputSizeH, \
gradOutputSizeW, \
gradInputSize, \
out_n_stride, \
in_n_stride, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w); \
sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \
}

template <typename scalar_t, bool is_channels_last>
void launch_max_pool2d_backward_kernel(
scalar_t* gradInput,
@@ -435,6 +600,112 @@ void launch_max_pool2d_backward_kernel(
  // with CUDA in AlexNet. To avoid future problems, we decided to always use
  // the deterministic path.

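  // Vectorization heuristic: pick the largest vec_size (up to 8) that both
  // divides the channel count and still yields enough sub-groups to fill more
  // than half of the hardware thread slots; otherwise fall through to the
  // scalar deterministic kernel below.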
int vec_size = 1;
int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
int num_sub_wg;
auto wg_size = syclDeviceMaxWorkGroupSize();
int64_t num_wg;
if constexpr (is_channels_last) {
for (vec_size = std::min(
8, memory::can_vectorize_up_to<scalar_t>((char*)gradOutput));
vec_size >= 1;
vec_size /= 2) {
if (numPlane % vec_size != 0) {
continue;
}
num_sub_wg = gradInputSize / vec_size / syclMaxSubGroupSize();
if (2 * num_sub_wg > thread_slots) {
int total_thread = gradInputSize / vec_size;
num_wg = (total_thread + wg_size - 1) / wg_size;
break;
}
}
switch (vec_size) {
case 8:
LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
scalar_t,
8,
num_wg,
wg_size,
queue,
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradInputSize,
out_n_stride,
in_n_stride,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w);
return;
case 4:
LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
scalar_t,
4,
num_wg,
wg_size,
queue,
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradInputSize,
out_n_stride,
in_n_stride,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w);
return;
case 2:
LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
scalar_t,
2,
num_wg,
wg_size,
queue,
gradInput,
gradOutput,
indices,
numPlane,
gradInputSizeH,
gradInputSizeW,
gradOutputSizeH,
gradOutputSizeW,
gradInputSize,
out_n_stride,
in_n_stride,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w);
return;
default:
break;
};
}
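  // Scalar deterministic path, also the fallback when no suitable vec_size
  // is found above.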
using KernelClass =
MaxPool2dBackwardDeterministicKernelFunctor<scalar_t, is_channels_last>;
BatchKernelConfig cfg = BatchKernelConfig::make_config<KernelClass>(
@@ -647,7 +918,6 @@ void max_pool2d_with_indices_backward_kernel(
inputHeight, kH, padH, dH, dilationH, ceil_mode);
int64_t outputWidth = pooling_output_shape<int64_t>(
inputWidth, kW, padW, dW, dilationW, ceil_mode);

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,