Skip to content

Commit f704d8f

Browse files
jerrymannilAMD
authored and committed
[ROCm] Unroll loads in global_reduce (#2554)
Cherry-pick of pytorch#161181. Fixes SWDEV-545710.
1 parent 125803b commit f704d8f

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,23 @@ struct ReduceOp {
822822
} else {
823823
index_t input_offset = threadIdx.y;
824824
index_t step = blockDim.y;
825+
#ifdef USE_ROCM // Prefetch loads to better hide their latency
826+
#define PRFCH 4
827+
for (; input_offset < config.ctas_per_output; input_offset += step*PRFCH) {
828+
arg_vec_t next[PRFCH];
829+
#pragma unroll
830+
for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
831+
index_t idx = config.staging_memory_offset(input_offset + u*step);
832+
next[u] = reduce_buffer[idx];
833+
}
834+
for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
835+
#pragma unroll
836+
for (int i = 0; i < output_vec_size; i++) {
837+
value[i] = ops.combine(value[i], next[u][i]);
838+
}
839+
}
840+
}
841+
#else
825842
for (; input_offset < config.ctas_per_output; input_offset += step) {
826843
index_t idx = config.staging_memory_offset(input_offset);
827844
arg_vec_t next = reduce_buffer[idx];
@@ -830,6 +847,7 @@ struct ReduceOp {
830847
value[i] = ops.combine(value[i], next[i]);
831848
}
832849
}
850+
#endif
833851
}
834852
value = block_y_reduce<output_vec_size>(value, shared_memory);
835853
if (config.should_block_x_reduce()) {

0 commit comments

Comments
 (0)