Commit 12bf7d0

doru1004 authored and jerrymannil committed
[ROCm] Limit number of values per thread for reductions on three dimensions (pytorch#159652)
In the current implementation of three-dimensional reductions on AMD GPUs, the number of values per thread is unbounded and can reach the hundreds of thousands for certain tensors, which hurts performance. This patch increases the parallelism so that the number of values per thread drops to a reasonable limit, i.e. below 2048. The performance gains range from 10x to 17x for examples where the number of values per thread was originally very high.

Pull Request resolved: pytorch#159652
Approved by: https://github.com/jeffdaily
1 parent 9948289 commit 12bf7d0

1 file changed: +14 -1 lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 14 additions & 1 deletion
@@ -209,6 +209,10 @@ struct ReduceConfig {
   int values_per_thread() const {
     return div_up(num_inputs, step_input);
   }
+
+  int mock_values_per_thread(int parallelism) {
+    return div_up(num_inputs, step_input * parallelism);
+  }
 };
 
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
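The new helper differs from `values_per_thread()` only in that it scales the divisor by the requested parallelism. The following is a minimal standalone sketch of that arithmetic, not the PyTorch code itself. It assumes `div_up` is ordinary ceiling division (its definition is not part of this diff), and `MiniConfig` is a hypothetical stand-in for `ReduceConfig` that keeps only the two fields used here.

```cpp
// Assumption: div_up is plain ceiling division for positive ints;
// its real definition is not shown in this diff.
static int div_up(int a, int b) { return (a + b - 1) / b; }

// Hypothetical stand-in for ReduceConfig with only the fields used here.
struct MiniConfig {
  int num_inputs;  // number of values to reduce per output
  int step_input;  // current stride between values handled by one thread

  // Same formula as values_per_thread() in the diff.
  int values_per_thread() const { return div_up(num_inputs, step_input); }

  // Same formula as mock_values_per_thread() in the diff: it only
  // computes a value and does not modify the config.
  int mock_values_per_thread(int parallelism) const {
    return div_up(num_inputs, step_input * parallelism);
  }
};
```

With the illustrative values `MiniConfig c{1000000, 4}`, `c.values_per_thread()` is 250000, while `c.mock_values_per_thread(128)` is 1954.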
@@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   else if (config.ctas_per_output < 16)
     config.ctas_per_output = 1;
   bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
-  if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
+  if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
     config.ctas_per_output = 4;
+    int vpt = config.values_per_thread();
+    // Capping the number of values per thread to 2048 for now
+    // based on known use cases.
+    while (vpt >= 2048) {
+      config.ctas_per_output *= 2;
+      // Computes the new values per thread without side effects
+      vpt = config.mock_values_per_thread(config.ctas_per_output);
+    }
+  }
 #endif
   if (config.ctas_per_output > 1) {
     config.input_mult[2] = config.split_input(config.ctas_per_output);
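To see how the doubling loop in this hunk behaves, it can be pulled out into a free function. This is only a sketch under stated assumptions: `capped_ctas_per_output` is a hypothetical name, `div_up` is assumed to be ceiling division, and the inputs are made-up sizes rather than measurements from the PR.

```cpp
// Assumption: div_up is plain ceiling division for positive ints.
static int div_up(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper mirroring the loop added in setReduceConfig:
// start at 4 CTAs per output and double until each thread would
// handle fewer than `limit` values.
int capped_ctas_per_output(int num_inputs, int step_input, int limit = 2048) {
  int ctas_per_output = 4;
  // As in the diff, the first check uses the count before any split.
  int vpt = div_up(num_inputs, step_input);
  while (vpt >= limit) {
    ctas_per_output *= 2;
    // Recompute without mutating any state, like mock_values_per_thread().
    vpt = div_up(num_inputs, step_input * ctas_per_output);
  }
  return ctas_per_output;
}
```

For `num_inputs = 1000000` and `step_input = 4`, the loop stops at `ctas_per_output = 128`, where each thread would handle 1954 values. For `num_inputs = 4096` the initial count of 1024 is already under the cap, so the value stays at 4.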
