Skip to content

Commit 10cbfa3

Browse files
authored
[ROCm] Unroll loads in global_reduce (#2554)
Cherry-pick of pytorch#161181. Fixes SWDEV-545710.
1 parent c00d48c commit 10cbfa3

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,23 @@ struct ReduceOp {
831831
} else {
832832
index_t input_offset = threadIdx.y;
833833
index_t step = blockDim.y;
834+
#ifdef USE_ROCM // Prefetch loads to better hide their latency
835+
#define PRFCH 4
836+
for (; input_offset < config.ctas_per_output; input_offset += step*PRFCH) {
837+
arg_vec_t next[PRFCH];
838+
#pragma unroll
839+
for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
840+
index_t idx = config.staging_memory_offset(input_offset + u*step);
841+
next[u] = reduce_buffer[idx];
842+
}
843+
for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
844+
#pragma unroll
845+
for (int i = 0; i < output_vec_size; i++) {
846+
value[i] = ops.combine(value[i], next[u][i]);
847+
}
848+
}
849+
}
850+
#else
834851
for (; input_offset < config.ctas_per_output; input_offset += step) {
835852
index_t idx = config.staging_memory_offset(input_offset);
836853
arg_vec_t next = reduce_buffer[idx];
@@ -839,6 +856,7 @@ struct ReduceOp {
839856
value[i] = ops.combine(value[i], next[i]);
840857
}
841858
}
859+
#endif
842860
}
843861
value = block_y_reduce<output_vec_size>(value, shared_memory);
844862
if (config.should_block_x_reduce()) {

0 commit comments

Comments
 (0)