diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp index d94db11c9..17a62436b 100644 --- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp +++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp @@ -11,7 +11,9 @@ #include #include +#include #include + #include #include @@ -151,6 +153,119 @@ struct MaxPool2dKernelFunctor { BatchKernelConfig cfg_; }; +template +struct MaxPool2dChannelLastVec { + void operator()(sycl::nd_item<1> item) const { + for (auto outputIndex = item.get_global_linear_id(); + outputIndex < numBatch_ * stride_ / vec_size; + outputIndex += item.get_local_range(0) * item.get_group_range(0)) { + int batch = outputIndex / (stride_ / vec_size); + int plane, outputH, outputW; + int64_t load_offset, store_offset; + plane = outputIndex % (numPlane_ / vec_size); + outputH = + outputIndex / (numPlane_ / vec_size) / outputSizeW_ % outputSizeH_; + outputW = outputIndex / (numPlane_ / vec_size) % outputSizeW_; + store_offset = outputIndex; + + vec_t maxVal_vec; +#pragma unroll + for (int i = 0; i < vec_size; i++) { + maxVal_vec[i] = at::numeric_limits::lower_bound(); + } + int64_t maxIndex[vec_size]; + for (int i = 0; i < vec_size; i++) { + maxIndex[i] = int64_t(-1); + } + int StartH = outputH * dH_ - padH_; + int StartW = outputW * dW_ - padW_; + int EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_); + int EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_); + while (StartH < 0) + StartH += dilationH_; + while (StartW < 0) + StartW += dilationW_; + for (int h = StartH; h < EndH; h += dilationH_) { + for (int w = StartW; w < EndW; w += dilationW_) { + load_offset = batch * inputSizeH_*inputSizeW_*numPlane_ / vec_size + plane + + h * inputSizeW_ * numPlane_ / vec_size + w * numPlane_ / vec_size; + vec_t val_vec = input_vec_[load_offset]; +#pragma unroll + for (int i = 0; i < vec_size; i++) { + if ((static_cast(val_vec[i]) > maxVal_vec[i]) || + at::_isnan(val_vec[i])) { + maxIndex[i] = h * inputSizeW_ + w; + maxVal_vec[i] = static_cast(val_vec[i]); + } + } + } + } +#pragma unroll + for (int i = 0; i < vec_size; i++) { + indices_[store_offset * vec_size + i] = maxIndex[i]; + } + output_vec_[store_offset] = maxVal_vec; + } + } + MaxPool2dChannelLastVec( + vec_t* output_vec, + int64_t* indices, + const vec_t* input_vec, + int numBatch, + int numPlane, + int inputSizeH, + int inputSizeW, + int outputSizeH, + int outputSizeW, + int kH, + int kW, + int dH, + int dW, + int padH, + int padW, + int dilationH, + int dilationW, + int stride) + : output_vec_(output_vec), + indices_(indices), + input_vec_(input_vec), + numBatch_(numBatch), + numPlane_(numPlane), + inputSizeH_(inputSizeH), + inputSizeW_(inputSizeW), + outputSizeH_(outputSizeH), + outputSizeW_(outputSizeW), + kH_(kH), + kW_(kW), + dH_(dH), + dW_(dW), + padH_(padH), + padW_(padW), + dilationH_(dilationH), + dilationW_(dilationW), + stride_(stride) {} + + private: + vec_t* output_vec_; + int64_t* indices_; + const vec_t* input_vec_; + int numBatch_; + int numPlane_; + int inputSizeH_; + int inputSizeW_; + int outputSizeH_; + int outputSizeW_; + int kH_; + int kW_; + int dH_; + int dW_; + int padH_; + int padW_; + int dilationH_; + int dilationW_; + int stride_; +}; + template struct MaxPool2dBackwardKernelFunctor { void operator()(sycl::nd_item<2> item) const { @@ -349,6 +464,56 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { BatchKernelConfig cfg_; }; +#define LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( \ + scalar_t, \ + vec_size, \ + num_wg, \ + wg_size, \ + queue, \ + output, \ + indices, \ + input, \ + numBatch, \ + numPlane, \ + inputSizeH, \ + inputSizeW, \ + outputSizeH, \ + outputSizeW, \ + kH, \ + kW, \ + dH, \ + dW, \ + padH, \ + padW, \ + dilationH, \ + dilationW, \ + stride) \ + { \ + using vec_t = memory::aligned_vector; \ + vec_t* output_vec = reinterpret_cast(output); \ + const vec_t* input_vec = reinterpret_cast(input); \ + auto kfn = MaxPool2dChannelLastVec( \ + output_vec, \ + indices, \ + input_vec, \ + numBatch, \ + numPlane, \ + inputSizeH, \ + inputSizeW, \ + outputSizeH, \ + outputSizeW, \ + kH, \ + kW, \ + dH, \ + dW, \ + padH, \ + padW, \ + dilationH, \ + dilationW, \ + stride); \ + sycl_kernel_submit(num_wg * wg_size, wg_size, queue, kfn); \ + } + template void launch_max_pool2d_kernel( scalar_t* output, @@ -368,11 +533,114 @@ void launch_max_pool2d_kernel( int padW, int dilationH, int dilationW) { - using KernelClass = MaxPool2dKernelFunctor; - auto& queue = at::xpu::getCurrentSYCLQueue(); int outputSize = numBatch * numPlane * outputSizeH * outputSizeW; int stride = numPlane * outputSizeH * outputSizeW; + int vec_size = 1; + int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); + int num_sub_wg; + auto wg_size = syclDeviceMaxWorkGroupSize(); + int64_t num_wg; + if constexpr (is_channels_last) { + for (vec_size = + std::min(8, memory::can_vectorize_up_to((char*)input)); + vec_size >= 1; + vec_size /= 2) { + if (numPlane % vec_size != 0) { + continue; + } + num_sub_wg = outputSize / vec_size / syclMaxSubGroupSize(); + if (2 * num_sub_wg > thread_slots) { + int total_thread = outputSize / vec_size; + num_wg = (total_thread + wg_size - 1) / wg_size; + break; + } + } + switch (vec_size) { + case 8: + LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( + scalar_t, + 8, + num_wg, + wg_size, + queue, + output, + indices, + input, + numBatch, + numPlane, + inputSizeH, + inputSizeW, + outputSizeH, + outputSizeW, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + stride); + return; + case 4: + LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( + scalar_t, + 4, + num_wg, + wg_size, + queue, + output, + indices, + input, + numBatch, + numPlane, + inputSizeH, + inputSizeW, + outputSizeH, + outputSizeW, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + stride); + return; + case 2: + LAUNCH_MAXPOOL_CHANNEL_LAST_VEC( + scalar_t, + 2, + num_wg, + wg_size, + queue, + output, + indices, + input, + numBatch, + numPlane, + inputSizeH, + inputSizeW, + outputSizeH, + outputSizeW, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + stride); + return; + default: + break; + }; + } + using KernelClass = MaxPool2dKernelFunctor; + BatchKernelConfig cfg = BatchKernelConfig::make_config( 1, outputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); auto kfn = KernelClass( @@ -704,6 +972,6 @@ void max_pool2d_with_indices_backward_kernel( } } // namespace at::native::xpu - +#undef LAUNCH_MAXPOOL_CHANNEL_LAST_VEC #pragma GCC diagnostic pop #pragma clang diagnostic pop