diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
index d94db11c9..a121eb7be 100644
--- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
+++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
@@ -11,6 +11,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -251,13 +252,14 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
       int pwstart = p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_);
       int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_);
+      scalar_t grad = 0;
       if constexpr (is_channels_last) {
         int offset = batch * out_n_stride_ + plane;
         for (int ph = phstart; ph < phend; ++ph) {
           for (int pw = pwstart; pw < pwend; ++pw) {
             if (indices_[offset + (ph * gradOutputSizeW_ + pw) * numPlane_] ==
                 input_hw_index) {
-              gradInput_[inputIndex] += static_cast<scalar_t>(
+              grad += static_cast<scalar_t>(
                   gradOutput_
                       [offset + (ph * gradOutputSizeW_ + pw) * numPlane_]);
             }
@@ -269,12 +271,13 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
           for (int pw = pwstart; pw < pwend; ++pw) {
             if (indices_[offset + ph * gradOutputSizeW_ + pw] ==
                 input_hw_index) {
-              gradInput_[inputIndex] += static_cast<scalar_t>(
+              grad += static_cast<scalar_t>(
                   gradOutput_[offset + ph * gradOutputSizeW_ + pw]);
             }
           }
         }
       }
+      gradInput_[inputIndex] = grad;
     }
   } while (cfg_.next(item, desc));
 }
@@ -349,6 +352,116 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
   BatchKernelConfig cfg_;
 };
 
+template <typename scalar_t, typename vec_t, int vec_size>
+struct MaxPool2dBackwardChannelLastVec {
+  void operator()(sycl::nd_item<1> item) const {
+    for (auto inputIndex = item.get_global_linear_id();
+         inputIndex < gradInputSize_ / vec_size;
+         inputIndex += item.get_local_range(0) * item.get_group_range(0)) {
+      int batch = inputIndex / (in_n_stride_ / vec_size);
+      int plane;
+      int64_t input_hw_index;
+
+      plane = inputIndex % (numPlane_ / vec_size);
+
+      input_hw_index = ((inputIndex % (in_n_stride_ / vec_size)) - plane) /
+          (numPlane_ / vec_size);
+
+      int inputW = input_hw_index % gradInputSizeW_;
+      int inputH = input_hw_index / gradInputSizeW_;
+      int phstart = p_start(inputH, pad_h_, kernel_h_, dilation_h_, stride_h_);
+      int phend = p_end(inputH, pad_h_, gradOutputSizeH_, stride_h_);
+      int pwstart = p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_);
+      int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_);
+      vec_t grad_vec;
+#pragma unroll
+      for (int i = 0; i < vec_size; i++) {
+        grad_vec[i] = 0;
+      }
+
+      int offset = batch * out_n_stride_ / vec_size + plane;
+      for (int ph = phstart; ph < phend; ++ph) {
+        for (int pw = pwstart; pw < pwend; ++pw) {
+          int load_offset = offset +
+              ph * gradOutputSizeW_ * numPlane_ / vec_size +
+              pw * numPlane_ / vec_size;
+          vec_t gout_val_vec = gradOutput_[load_offset];
+#pragma unroll
+          for (int i = 0; i < vec_size; i++) {
+            if (indices_[load_offset * vec_size + i] == input_hw_index) {
+              grad_vec[i] = static_cast<scalar_t>(grad_vec[i]) +
+                  static_cast<scalar_t>(gout_val_vec[i]);
+            }
+          }
+        }
+      }
+
+      gradInput_[inputIndex] = grad_vec;
+    }
+  }
+  MaxPool2dBackwardChannelLastVec(
+      vec_t* gradInput,
+      const vec_t* gradOutput,
+      const int64_t* indices,
+      int numPlane,
+      int gradInputSizeH,
+      int gradInputSizeW,
+      int gradOutputSizeH,
+      int gradOutputSizeW,
+      int64_t gradInputSize,
+      int out_n_stride,
+      int in_n_stride,
+      int kernel_h,
+      int kernel_w,
+      int stride_h,
+      int stride_w,
+      int pad_h,
+      int pad_w,
+      int dilation_h,
+      int dilation_w)
+      : gradInput_(gradInput),
+        gradOutput_(gradOutput),
+        indices_(indices),
+        numPlane_(numPlane),
+        gradInputSizeH_(gradInputSizeH),
+        gradInputSizeW_(gradInputSizeW),
+        gradOutputSizeH_(gradOutputSizeH),
+        gradOutputSizeW_(gradOutputSizeW),
+        gradInputSize_(gradInputSize),
+        out_n_stride_(out_n_stride),
+        in_n_stride_(in_n_stride),
+        kernel_h_(kernel_h),
+        kernel_w_(kernel_w),
+        stride_h_(stride_h),
+        stride_w_(stride_w),
+        pad_h_(pad_h),
+        pad_w_(pad_w),
+        dilation_h_(dilation_h),
+        dilation_w_(dilation_w) {}
+
+ private:
+  vec_t* gradInput_;
+  const vec_t* gradOutput_;
+  const int64_t* indices_;
+  int numPlane_;
+  int gradInputSizeH_;
+  int gradInputSizeW_;
+  int gradOutputSizeH_;
+  int gradOutputSizeW_;
+  int64_t gradInputSize_;
+  int out_n_stride_;
+  int in_n_stride_;
+  int kernel_h_;
+  int kernel_w_;
+  int stride_h_;
+  int stride_w_;
+  int pad_h_;
+  int pad_w_;
+  int dilation_h_;
+  int dilation_w_;
+};
+
+
 template <typename scalar_t, bool is_channels_last>
 void launch_max_pool2d_kernel(
     scalar_t* output,
@@ -397,6 +510,58 @@ void launch_max_pool2d_kernel(
   sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
 }
 
+#define LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( \
+    scalar_t, \
+    vec_size, \
+    num_wg, \
+    wg_size, \
+    queue, \
+    gradInput, \
+    gradOutput, \
+    indices, \
+    numPlane, \
+    gradInputSizeH, \
+    gradInputSizeW, \
+    gradOutputSizeH, \
+    gradOutputSizeW, \
+    gradInputSize, \
+    out_n_stride, \
+    in_n_stride, \
+    kernel_h, \
+    kernel_w, \
+    stride_h, \
+    stride_w, \
+    pad_h, \
+    pad_w, \
+    dilation_h, \
+    dilation_w) \
+  { \
+    using vec_t = memory::aligned_vector<scalar_t, vec_size>; \
+    const vec_t* grad_output_vec = reinterpret_cast<const vec_t*>(gradOutput); \
+    vec_t* grad_input_vec = reinterpret_cast<vec_t*>(gradInput); \
+    auto kfn = MaxPool2dBackwardChannelLastVec<scalar_t, vec_t, vec_size>( \
+        grad_input_vec, \
+        grad_output_vec, \
+        indices, \
+        numPlane, \
+        gradInputSizeH, \
+        gradInputSizeW, \
+        gradOutputSizeH, \
+        gradOutputSizeW, \
+        gradInputSize, \
+        out_n_stride, \
+        in_n_stride, \
+        kernel_h, \
+        kernel_w, \
+        stride_h, \
+        stride_w, \
+        pad_h, \
+        pad_w, \
+        dilation_h, \
+        dilation_w); \
+    sycl_kernel_submit(num_wg * wg_size, wg_size, queue, kfn); \
+  }
+
 template <typename scalar_t, bool is_channels_last>
 void launch_max_pool2d_backward_kernel(
     scalar_t* gradInput,
@@ -435,6 +600,112 @@
   // with CUDA in alexnet To avoid future problem, we decided to always use
   // deterministic path.
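+  // Vectorized channels-last path: choose the widest vector load (up to 8
+  // elements) that divides the channel count evenly, provided the reduced
+  // work-item count still oversubscribes the hardware thread slots by 2x.
+  // If no width qualifies, fall through to the deterministic scalar kernel.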
+  int vec_size = 1;
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  int num_sub_wg;
+  auto wg_size = syclDeviceMaxWorkGroupSize();
+  int64_t num_wg;
+  if constexpr (is_channels_last) {
+    for (vec_size = std::min(
+             8, memory::can_vectorize_up_to<scalar_t>((char*)gradOutput));
+         vec_size >= 1;
+         vec_size /= 2) {
+      if (numPlane % vec_size != 0) {
+        continue;
+      }
+      num_sub_wg = gradInputSize / vec_size / syclMaxSubGroupSize();
+      if (2 * num_sub_wg > thread_slots) {
+        int total_thread = gradInputSize / vec_size;
+        num_wg = (total_thread + wg_size - 1) / wg_size;
+        break;
+      }
+    }
+    switch (vec_size) {
+      case 8:
+        LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
+            scalar_t,
+            8,
+            num_wg,
+            wg_size,
+            queue,
+            gradInput,
+            gradOutput,
+            indices,
+            numPlane,
+            gradInputSizeH,
+            gradInputSizeW,
+            gradOutputSizeH,
+            gradOutputSizeW,
+            gradInputSize,
+            out_n_stride,
+            in_n_stride,
+            kernel_h,
+            kernel_w,
+            stride_h,
+            stride_w,
+            pad_h,
+            pad_w,
+            dilation_h,
+            dilation_w);
+        return;
+      case 4:
+        LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
+            scalar_t,
+            4,
+            num_wg,
+            wg_size,
+            queue,
+            gradInput,
+            gradOutput,
+            indices,
+            numPlane,
+            gradInputSizeH,
+            gradInputSizeW,
+            gradOutputSizeH,
+            gradOutputSizeW,
+            gradInputSize,
+            out_n_stride,
+            in_n_stride,
+            kernel_h,
+            kernel_w,
+            stride_h,
+            stride_w,
+            pad_h,
+            pad_w,
+            dilation_h,
+            dilation_w);
+        return;
+      case 2:
+        LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
+            scalar_t,
+            2,
+            num_wg,
+            wg_size,
+            queue,
+            gradInput,
+            gradOutput,
+            indices,
+            numPlane,
+            gradInputSizeH,
+            gradInputSizeW,
+            gradOutputSizeH,
+            gradOutputSizeW,
+            gradInputSize,
+            out_n_stride,
+            in_n_stride,
+            kernel_h,
+            kernel_w,
+            stride_h,
+            stride_w,
+            pad_h,
+            pad_w,
+            dilation_h,
+            dilation_w);
+        return;
+      default:
+        break;
+    };
+  }
   using KernelClass =
       MaxPool2dBackwardDeterministicKernelFunctor<scalar_t, is_channels_last>;
   BatchKernelConfig cfg = BatchKernelConfig::make_config(
@@ -647,7 +918,6 @@ void max_pool2d_with_indices_backward_kernel(
       inputHeight, kH, padH, dH, dilationH, ceil_mode);
   int64_t outputWidth = pooling_output_shape<int64_t>(
       inputWidth, kW, padW, dW, dilationW, ceil_mode);
-
   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
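
Note on the launch heuristic: the vector-width selection in launch_max_pool2d_backward_kernel is the core of this change, so here is a minimal stand-alone sketch of the same logic for reference. The function name pick_vec_size and its parameters are hypothetical stand-ins for memory::can_vectorize_up_to<scalar_t>(), syclMaxSubGroupSize(), and syclGpuEuCount() * syclGpuHWThreadsPerEU() in the patch; this is a reading of the code above, not torch-xpu-ops API.

#include <algorithm>
#include <cstdint>

// Hypothetical stand-alone restatement of the patch's vector-width heuristic.
// Returns the widest power-of-two vector width, capped at 8, that (a) divides
// the channel count so a vec_t load never straddles two spatial positions in
// channels-last layout, and (b) still leaves enough sub-groups to
// oversubscribe the hardware thread slots by 2x.
int pick_vec_size(
    int max_vec,              // cap from memory::can_vectorize_up_to<scalar_t>
    int num_plane,            // channel count (numPlane in the patch)
    int64_t grad_input_size,  // total gradInput elements
    int sub_group_size,       // syclMaxSubGroupSize()
    int thread_slots) {       // syclGpuEuCount() * syclGpuHWThreadsPerEU()
  for (int vec = std::min(8, max_vec); vec > 1; vec /= 2) {
    if (num_plane % vec != 0) {
      continue;  // a vector load would cross a pixel boundary
    }
    int64_t num_sub_wg = grad_input_size / vec / sub_group_size;
    if (2 * num_sub_wg > thread_slots) {
      return vec;  // wide loads and still enough parallelism
    }
  }
  return 1;  // fall back to the scalar deterministic kernel
}

The 2x oversubscription guard is the interesting design point: each doubling of vec_size halves the number of work-items, so for small gradients a narrower or scalar kernel keeps more EU thread slots busy than wide vector loads would.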