-
Notifications
You must be signed in to change notification settings - Fork 49
add vectorization path on maxpool backward channel last #1907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
ac60deb
76a5583
18eb934
cefb88a
25d2766
357beb4
ead7971
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -11,6 +11,7 @@ | |||||||
|
||||||||
#include <ATen/native/xpu/sycl/Atomics.h> | ||||||||
#include <ATen/native/xpu/sycl/BatchKernel.h> | ||||||||
#include <ATen/native/xpu/sycl/MemoryAccess.h> | ||||||||
#include <ATen/native/xpu/sycl/NumericLimits.h> | ||||||||
#include <comm/Runtime.h> | ||||||||
#include <comm/SYCLHelpers.h> | ||||||||
|
@@ -251,30 +252,33 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { | |||||||
int pwstart = | ||||||||
p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_); | ||||||||
int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_); | ||||||||
scalar_t grad = 0; | ||||||||
if constexpr (is_channels_last) { | ||||||||
int offset = batch * out_n_stride_ + plane; | ||||||||
for (int ph = phstart; ph < phend; ++ph) { | ||||||||
for (int pw = pwstart; pw < pwend; ++pw) { | ||||||||
if (indices_[offset + (ph * gradOutputSizeW_ + pw) * numPlane_] == | ||||||||
input_hw_index) { | ||||||||
gradInput_[inputIndex] += static_cast<scalar_t>( | ||||||||
grad += static_cast<scalar_t>( | ||||||||
gradOutput_ | ||||||||
[offset + (ph * gradOutputSizeW_ + pw) * numPlane_]); | ||||||||
} | ||||||||
} | ||||||||
} | ||||||||
} else { | ||||||||
} | ||||||||
else { | ||||||||
int offset = batch * out_n_stride_ + plane * out_cf_c_stride_; | ||||||||
for (int ph = phstart; ph < phend; ++ph) { | ||||||||
for (int pw = pwstart; pw < pwend; ++pw) { | ||||||||
if (indices_[offset + ph * gradOutputSizeW_ + pw] == | ||||||||
input_hw_index) { | ||||||||
gradInput_[inputIndex] += static_cast<scalar_t>( | ||||||||
grad += static_cast<scalar_t>( | ||||||||
gradOutput_[offset + ph * gradOutputSizeW_ + pw]); | ||||||||
} | ||||||||
} | ||||||||
} | ||||||||
} | ||||||||
gradInput_[inputIndex] = grad; | ||||||||
} | ||||||||
} while (cfg_.next(item, desc)); | ||||||||
} | ||||||||
|
@@ -349,6 +353,122 @@ struct MaxPool2dBackwardDeterministicKernelFunctor { | |||||||
BatchKernelConfig cfg_; | ||||||||
}; | ||||||||
|
||||||||
template <typename scalar_t, typename vec_t, int vec_size> | ||||||||
struct MaxPool2dBackwardChannelLastVec { | ||||||||
void operator()(sycl::nd_item<1> item) const { | ||||||||
for (auto inputIndex = item.get_global_linear_id(); | ||||||||
inputIndex < gradInputSize_ / vec_size; | ||||||||
inputIndex += item.get_local_range(0) * item.get_group_range(0)) { | ||||||||
int batch = inputIndex / (in_n_stride_ / vec_size); | ||||||||
int plane; | ||||||||
int64_t input_hw_index; | ||||||||
|
||||||||
plane = inputIndex % (numPlane_ / vec_size); | ||||||||
input_hw_index = | ||||||||
((inputIndex % in_n_stride_) - plane) / (numPlane_ / vec_size); | ||||||||
|
||||||||
int inputW = input_hw_index % gradInputSizeW_; | ||||||||
int inputH = input_hw_index / gradInputSizeW_; | ||||||||
int phstart = p_start(inputH, pad_h_, kernel_h_, dilation_h_, stride_h_); | ||||||||
int phend = p_end(inputH, pad_h_, gradOutputSizeH_, stride_h_); | ||||||||
int pwstart = p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_); | ||||||||
int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_); | ||||||||
scalar_t grad = 0; | ||||||||
chunhuanMeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
int64_t load_offset, store_offset; | ||||||||
store_offset = inputIndex; | ||||||||
vec_t grad_vec; | ||||||||
#pragma unroll | ||||||||
for (int i = 0; i < vec_size; i++) { | ||||||||
grad_vec[i] = 0; | ||||||||
} | ||||||||
|
||||||||
int offset = batch * (out_n_stride_ / vec_size) + plane; | ||||||||
for (int ph = phstart; ph < phend; ++ph) { | ||||||||
for (int pw = pwstart; pw < pwend; ++pw) { | ||||||||
load_offset = | ||||||||
offset + (ph * gradOutputSizeW_ + pw) * (numPlane_ / vec_size); | ||||||||
vec_t gout_val_vec = gradOutput_[load_offset]; | ||||||||
#pragma unroll | ||||||||
for (int i = 0; i < vec_size; i++) { | ||||||||
if (indices_[load_offset * vec_size + i] == input_hw_index) { | ||||||||
grad_vec[i] = static_cast<scalar_t>(grad_vec[i]) + | ||||||||
static_cast<scalar_t>(gout_val_vec[i]); | ||||||||
Comment on lines
+392
to
+393
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cast
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||
} | ||||||||
} | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
gradInput_[store_offset] = grad_vec; | ||||||||
} | ||||||||
} | ||||||||
MaxPool2dBackwardChannelLastVec( | ||||||||
vec_t* gradInput, | ||||||||
const vec_t* gradOutput, | ||||||||
const int64_t* indices, | ||||||||
int numPlane, | ||||||||
int gradInputSizeH, | ||||||||
int gradInputSizeW, | ||||||||
int gradOutputSizeH, | ||||||||
int gradOutputSizeW, | ||||||||
int64_t gradInputSize, | ||||||||
int out_cf_c_stride, | ||||||||
int in_cf_c_stride, | ||||||||
int out_n_stride, | ||||||||
int in_n_stride, | ||||||||
int kernel_h, | ||||||||
int kernel_w, | ||||||||
int stride_h, | ||||||||
int stride_w, | ||||||||
int pad_h, | ||||||||
int pad_w, | ||||||||
int dilation_h, | ||||||||
int dilation_w) | ||||||||
: gradInput_(gradInput), | ||||||||
gradOutput_(gradOutput), | ||||||||
indices_(indices), | ||||||||
numPlane_(numPlane), | ||||||||
gradInputSizeH_(gradInputSizeH), | ||||||||
gradInputSizeW_(gradInputSizeW), | ||||||||
gradOutputSizeH_(gradOutputSizeH), | ||||||||
gradOutputSizeW_(gradOutputSizeW), | ||||||||
gradInputSize_(gradInputSize), | ||||||||
out_cf_c_stride_(out_cf_c_stride), | ||||||||
in_cf_c_stride_(in_cf_c_stride), | ||||||||
out_n_stride_(out_n_stride), | ||||||||
in_n_stride_(in_n_stride), | ||||||||
kernel_h_(kernel_h), | ||||||||
kernel_w_(kernel_w), | ||||||||
stride_h_(stride_h), | ||||||||
stride_w_(stride_w), | ||||||||
pad_h_(pad_h), | ||||||||
pad_w_(pad_w), | ||||||||
dilation_h_(dilation_h), | ||||||||
dilation_w_(dilation_w) {} | ||||||||
|
||||||||
private: | ||||||||
vec_t* gradInput_; | ||||||||
const vec_t* gradOutput_; | ||||||||
const int64_t* indices_; | ||||||||
int numPlane_; | ||||||||
int gradInputSizeH_; | ||||||||
int gradInputSizeW_; | ||||||||
int gradOutputSizeH_; | ||||||||
int gradOutputSizeW_; | ||||||||
int64_t gradInputSize_; | ||||||||
int out_cf_c_stride_; | ||||||||
int in_cf_c_stride_; | ||||||||
int out_n_stride_; | ||||||||
int in_n_stride_; | ||||||||
int kernel_h_; | ||||||||
int kernel_w_; | ||||||||
int stride_h_; | ||||||||
int stride_w_; | ||||||||
int pad_h_; | ||||||||
int pad_w_; | ||||||||
int dilation_h_; | ||||||||
int dilation_w_; | ||||||||
}; | ||||||||
|
||||||||
template <typename scalar_t, bool is_channels_last> | ||||||||
void launch_max_pool2d_kernel( | ||||||||
scalar_t* output, | ||||||||
|
@@ -397,6 +517,62 @@ void launch_max_pool2d_kernel( | |||||||
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn); | ||||||||
} | ||||||||
|
||||||||
#define LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( \ | ||||||||
scalar_t, \ | ||||||||
vec_size, \ | ||||||||
num_wg, \ | ||||||||
wg_size, \ | ||||||||
queue, \ | ||||||||
gradInput, \ | ||||||||
gradOutput, \ | ||||||||
indices, \ | ||||||||
numPlane, \ | ||||||||
gradInputSizeH, \ | ||||||||
gradInputSizeW, \ | ||||||||
gradOutputSizeH, \ | ||||||||
gradOutputSizeW, \ | ||||||||
gradInputSize, \ | ||||||||
out_cf_c_stride, \ | ||||||||
in_cf_c_stride, \ | ||||||||
out_n_stride, \ | ||||||||
in_n_stride, \ | ||||||||
kernel_h, \ | ||||||||
kernel_w, \ | ||||||||
stride_h, \ | ||||||||
stride_w, \ | ||||||||
pad_h, \ | ||||||||
pad_w, \ | ||||||||
dilation_h, \ | ||||||||
dilation_w) \ | ||||||||
{ \ | ||||||||
using vec_t = memory::aligned_vector<scalar_t, vec_size>; \ | ||||||||
const vec_t* grad_output_vec = reinterpret_cast<const vec_t*>(gradOutput); \ | ||||||||
vec_t* grad_input_vec = reinterpret_cast<vec_t*>(gradInput); \ | ||||||||
auto kfn = MaxPool2dBackwardChannelLastVec<scalar_t, vec_t, vec_size>( \ | ||||||||
grad_input_vec, \ | ||||||||
grad_output_vec, \ | ||||||||
indices, \ | ||||||||
numPlane, \ | ||||||||
gradInputSizeH, \ | ||||||||
gradInputSizeW, \ | ||||||||
gradOutputSizeH, \ | ||||||||
gradOutputSizeW, \ | ||||||||
gradInputSize, \ | ||||||||
out_cf_c_stride, \ | ||||||||
in_cf_c_stride, \ | ||||||||
out_n_stride, \ | ||||||||
in_n_stride, \ | ||||||||
kernel_h, \ | ||||||||
kernel_w, \ | ||||||||
stride_h, \ | ||||||||
stride_w, \ | ||||||||
pad_h, \ | ||||||||
pad_w, \ | ||||||||
dilation_h, \ | ||||||||
dilation_w); \ | ||||||||
sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \ | ||||||||
} | ||||||||
|
||||||||
template <typename scalar_t, bool is_channels_last> | ||||||||
void launch_max_pool2d_backward_kernel( | ||||||||
scalar_t* gradInput, | ||||||||
|
@@ -435,6 +611,119 @@ void launch_max_pool2d_backward_kernel( | |||||||
// with CUDA in alexnet To avoid future problem, we decided to always use | ||||||||
// deterministic path. | ||||||||
|
||||||||
|
||||||||
// int vec_size = 1; | ||||||||
// int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); | ||||||||
// int num_sub_wg; | ||||||||
// auto wg_size = syclDeviceMaxWorkGroupSize(); | ||||||||
// int64_t num_wg; | ||||||||
// if constexpr (is_channels_last) { | ||||||||
// for (vec_size = std::min( | ||||||||
// 8, memory::can_vectorize_up_to<scalar_t>((char*)gradOutput)); | ||||||||
// vec_size >= 1; | ||||||||
// vec_size /= 2) { | ||||||||
// if (numPlane % vec_size != 0) { | ||||||||
// continue; | ||||||||
// } | ||||||||
// num_sub_wg = gradInputSize / vec_size / syclMaxSubGroupSize(); | ||||||||
// if (2 * num_sub_wg > thread_slots) { | ||||||||
// int total_thread = gradInputSize / vec_size; | ||||||||
// num_wg = (total_thread + wg_size - 1) / wg_size; | ||||||||
// break; | ||||||||
// } | ||||||||
// } | ||||||||
// switch (vec_size) { | ||||||||
// case 8: | ||||||||
// LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( | ||||||||
// scalar_t, | ||||||||
// 8, | ||||||||
// num_wg, | ||||||||
// wg_size, | ||||||||
// queue, | ||||||||
// gradInput, | ||||||||
// gradOutput, | ||||||||
// indices, | ||||||||
// numPlane, | ||||||||
// gradInputSizeH, | ||||||||
// gradInputSizeW, | ||||||||
// gradOutputSizeH, | ||||||||
// gradOutputSizeW, | ||||||||
// gradInputSize, | ||||||||
// out_cf_c_stride, | ||||||||
// in_cf_c_stride, | ||||||||
// out_n_stride, | ||||||||
// in_n_stride, | ||||||||
// kernel_h, | ||||||||
// kernel_w, | ||||||||
// stride_h, | ||||||||
// stride_w, | ||||||||
// pad_h, | ||||||||
// pad_w, | ||||||||
// dilation_h, | ||||||||
// dilation_w); | ||||||||
// return; | ||||||||
// case 4: | ||||||||
// LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( | ||||||||
// scalar_t, | ||||||||
// 1, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The vec_size parameter should be 4, not 1, for the case 4 branch. This appears to be a copy-paste error that would prevent proper vectorization when vec_size is 4.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||
// num_wg, | ||||||||
// wg_size, | ||||||||
// queue, | ||||||||
// gradInput, | ||||||||
// gradOutput, | ||||||||
// indices, | ||||||||
// numPlane, | ||||||||
// gradInputSizeH, | ||||||||
// gradInputSizeW, | ||||||||
// gradOutputSizeH, | ||||||||
// gradOutputSizeW, | ||||||||
// gradInputSize, | ||||||||
// out_cf_c_stride, | ||||||||
// in_cf_c_stride, | ||||||||
// out_n_stride, | ||||||||
// in_n_stride, | ||||||||
// kernel_h, | ||||||||
// kernel_w, | ||||||||
// stride_h, | ||||||||
// stride_w, | ||||||||
// pad_h, | ||||||||
// pad_w, | ||||||||
// dilation_h, | ||||||||
// dilation_w); | ||||||||
// return; | ||||||||
// case 2: | ||||||||
// LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( | ||||||||
// scalar_t, | ||||||||
// 2, | ||||||||
// num_wg, | ||||||||
// wg_size, | ||||||||
// queue, | ||||||||
// gradInput, | ||||||||
// gradOutput, | ||||||||
// indices, | ||||||||
// numPlane, | ||||||||
// gradInputSizeH, | ||||||||
// gradInputSizeW, | ||||||||
// gradOutputSizeH, | ||||||||
// gradOutputSizeW, | ||||||||
// gradInputSize, | ||||||||
// out_cf_c_stride, | ||||||||
// in_cf_c_stride, | ||||||||
// out_n_stride, | ||||||||
// in_n_stride, | ||||||||
// kernel_h, | ||||||||
// kernel_w, | ||||||||
// stride_h, | ||||||||
// stride_w, | ||||||||
// pad_h, | ||||||||
// pad_w, | ||||||||
// dilation_h, | ||||||||
// dilation_w); | ||||||||
// return; | ||||||||
// default: | ||||||||
// break; | ||||||||
// }; | ||||||||
// } | ||||||||
using KernelClass = | ||||||||
MaxPool2dBackwardDeterministicKernelFunctor<scalar_t, is_channels_last>; | ||||||||
BatchKernelConfig cfg = BatchKernelConfig::make_config<KernelClass>( | ||||||||
|
Uh oh!
There was an error while loading. Please reload this page.