Skip to content

Commit 82bd728

Browse files
jerrymannilAMD AMD
authored and committed
[ROCm] [Normalization] Update block size (#2738)
Cherry-pick of pytorch@9f82535. Fixes #SWDEV-561122.
1 parent 330f52d commit 82bd728

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

aten/src/ATen/native/cuda/Normalization.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ namespace at::native {
2323

2424
// The maximum number of threads in a block
2525
#if defined(USE_ROCM)
26-
constexpr int MAX_BLOCK_SIZE = 256;
26+
constexpr int MAX_BLOCK_SIZE = 1024;
2727
#else
2828
constexpr int MAX_BLOCK_SIZE = 512;
2929
#endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
3333
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
3434
static int getNumThreads(int nElem) {
3535
#if defined(USE_ROCM)
36-
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
36+
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
3737
#else
3838
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
3939
#endif

0 commit comments

Comments
 (0)