
Commit 9d15d89

jerrymannil and glen-amd committed

[ROCm] Enable more parallelism for multi-dimensional reductions (#2291)

Cherry-pick of pytorch@085f270 into rocm/pytorch:release/2.7.
Co-authored-by: Doru Bercea, Glen Cao <[email protected]>

1 parent: dae14f9

File tree

1 file changed: +13 -7 lines

aten/src/ATen/native/cuda/Reduce.cuh (13 additions, 7 deletions)
@@ -1115,13 +1115,19 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   int max_threads_per_mp =
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
 #ifdef USE_ROCM
-  // Control the number of threadblocks by adjusting the maximum number of
-  // threads per multi-processor. These numbers better reflect the maximum
-  // theoretical achievable threads per MP for the reduction operation.
-  if (iter.ndim() == 1 || iter.ndim() == 3)
-    max_threads_per_mp = 512;
-  if (iter.ndim() == 2)
-    max_threads_per_mp = 256;
+  // If the grid consists of a single threadblock, do not change the max threads per
+  // MP value. This will increase the parallelism across the y dimension of the grid.
+  bool uses_a_single_block = config.grid().x == config.grid().y == config.grid().z == 1;
+
+  if (!uses_a_single_block) {
+    // Control the number of threadblocks by adjusting the maximum number of
+    // threads per multi-processor. These numbers better reflect the maximum
+    // theoretical achievable threads per MP for the reduction operation.
+    if (iter.ndim() == 1 || iter.ndim() == 3)
+      max_threads_per_mp = 512;
+    else if (iter.ndim() == 2)
+      max_threads_per_mp = 256;
+  }
 #endif
   const int blocks_per_sm = max_threads_per_mp / config.num_threads;
   const int target_grid_size = num_mp * blocks_per_sm;
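For context on why the single-block case matters, below is a minimal standalone C++ sketch of the heuristic this diff changes. The Dim3 struct and the free parameters (ndim, num_threads, num_mp) are hypothetical stand-ins for the real ReduceConfig and device-properties plumbing in Reduce.cuh, not actual PyTorch API. The sketch also writes the single-block test as three explicit comparisons; the committed chained form `x == y == z == 1` parses in C++ as `((x == y) == z) == 1`, which is not the same predicate as "every grid dimension equals 1".

#include <cstdio>

// Hypothetical stand-in for the launch grid (dim3 in CUDA/HIP).
struct Dim3 { int x, y, z; };

// Sketch of the patched heuristic: cap max_threads_per_mp only when the
// grid already has more than one threadblock. A single-block grid keeps
// the device's full thread budget, so the target_grid_size computed below
// leaves room to grow the grid along y for multi-dimensional reductions.
int target_grid_size(Dim3 grid, int ndim, int num_threads,
                     int max_threads_per_mp, int num_mp) {
  bool uses_a_single_block =
      grid.x == 1 && grid.y == 1 && grid.z == 1;
  if (!uses_a_single_block) {
    if (ndim == 1 || ndim == 3)
      max_threads_per_mp = 512;
    else if (ndim == 2)
      max_threads_per_mp = 256;
  }
  int blocks_per_sm = max_threads_per_mp / num_threads;
  return num_mp * blocks_per_sm;
}

int main() {
  // A 2-D reduction with 256-thread blocks on a device with 104 MPs and
  // 2048 max threads per MP: the 256 cap yields 1 block/SM -> 104 blocks.
  Dim3 grid{8, 4, 1};
  printf("%d\n",
         target_grid_size(grid, /*ndim=*/2, /*num_threads=*/256,
                          /*max_threads_per_mp=*/2048, /*num_mp=*/104));
  // A single-block launch keeps the full 2048 budget instead:
  // 2048 / 256 = 8 blocks/SM -> 832 target blocks, i.e. more y-parallelism.
  Dim3 single{1, 1, 1};
  printf("%d\n",
         target_grid_size(single, /*ndim=*/2, /*num_threads=*/256,
                          /*max_threads_per_mp=*/2048, /*num_mp=*/104));
  return 0;
}

Under these assumed numbers, skipping the cap for single-block grids raises the target grid size from 104 to 832, which is the extra parallelism the commit message refers to.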
