Skip to content

Commit e2d141d

Browse files
pytorchbot and ngimel
authored
set thread_work_size to 4 for unrolled kernel (pytorch#154541)
set thread_work_size to 4 for unrolled kernel (pytorch#152396). Previous PRs enabling 8-vectorization inadvertently regressed unrolled-kernel perf. Pull Request resolved: pytorch#152396. Approved by: https://github.com/BoyuanFeng, https://github.com/msaroufim, https://github.com/malfet, https://github.com/Aidyn-A, https://github.com/atalman. (cherry picked from commit adebb8b) Co-authored-by: Natalia Gimelshein <[email protected]>
1 parent 1214198 commit e2d141d

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ constexpr auto elems_per_thread(){
8383
}
8484
#endif
8585

// Per-thread work size for the unrolled elementwise kernel.
// A thread work size of 8 regresses the perf of the elementwise kernel on
// CUDA; 4 restores the previous behavior. This does not change ROCm
// behavior, as thread_work_size is already 4 on ROCm.
constexpr int elementwise_thread_work_size() {
  return 4;
}
90+
constexpr int elementwise_block_work_size() {
91+
return elementwise_thread_work_size() * num_threads();
92+
}
93+
8694
template <int io_sizes>
8795
constexpr auto io_block_work_size() {
8896
return num_threads() * elems_per_thread<io_sizes>();
@@ -336,9 +344,10 @@ static inline void launch_unrolled_kernel(
336344
loader_t l,
337345
storer_t s) {
338346
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
339-
int64_t grid = (N + block_work_size() - 1) / block_work_size();
347+
348+
int64_t grid = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size();
340349
auto stream = at::cuda::getCurrentCUDAStream();
341-
unrolled_elementwise_kernel<func_t, array_t, thread_work_size()>
350+
unrolled_elementwise_kernel<func_t, array_t, elementwise_thread_work_size()>
342351
<<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
343352
C10_CUDA_KERNEL_LAUNCH_CHECK();
344353
}

0 commit comments

Comments
 (0)