diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 7d1c45e785b79..7cc71711d01d6 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -822,6 +822,23 @@ struct ReduceOp {
       } else {
         index_t input_offset = threadIdx.y;
         index_t step = blockDim.y;
+#ifdef USE_ROCM // Prefetch loads to better hide their latency
+        #define PRFCH 4
+        for (; input_offset < config.ctas_per_output; input_offset += step*PRFCH) {
+          arg_vec_t next[PRFCH];
+          #pragma unroll
+          for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
+            index_t idx = config.staging_memory_offset(input_offset + u*step);
+            next[u] = reduce_buffer[idx];
+          }
+          for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
+            #pragma unroll
+            for (int i = 0; i < output_vec_size; i++) {
+              value[i] = ops.combine(value[i], next[u][i]);
+            }
+          }
+        }
+#else
         for (; input_offset < config.ctas_per_output; input_offset += step) {
           index_t idx = config.staging_memory_offset(input_offset);
           arg_vec_t next = reduce_buffer[idx];
@@ -830,6 +847,7 @@ struct ReduceOp {
             value[i] = ops.combine(value[i], next[i]);
           }
         }
+#endif
       }
       value = block_y_reduce(value, shared_memory);
       if (config.should_block_x_reduce()) {
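
For context, the change above replaces the one-at-a-time walk over the staging buffer with a batched version on ROCm: each iteration first issues up to PRFCH independent loads from reduce_buffer into a register array, and only then combines them into value, so the loads' memory latency overlaps instead of being paid serially. The standalone sketch below illustrates the same load-batching pattern outside of Reduce.cuh; every name in it (strided_reduce, combine, buffer, n) is illustrative and not part of the PyTorch code.

#include <cuda_runtime.h>
#include <cstdio>

#define PRFCH 4  // prefetch depth, mirroring the value used in the patch

// Stand-in for ops.combine; a plain sum is assumed purely for illustration.
__device__ float combine(float a, float b) { return a + b; }

// Each thread reduces a strided slice of `buffer`. Per iteration it issues up
// to PRFCH loads back-to-back into a register array, then consumes them, so
// the loads' latency overlaps rather than being paid one load at a time.
__global__ void strided_reduce(const float* buffer, int n, float* out) {
  int start  = threadIdx.x;
  int stride = blockDim.x;
  float acc = 0.0f;
  for (int i = start; i < n; i += stride * PRFCH) {
    float next[PRFCH];
    #pragma unroll
    for (int u = 0; (u < PRFCH) && (i + u * stride < n); u++) {
      next[u] = buffer[i + u * stride];   // batch of independent loads
    }
    for (int u = 0; (u < PRFCH) && (i + u * stride < n); u++) {
      acc = combine(acc, next[u]);        // combine once the loads are in flight
    }
  }
  atomicAdd(out, acc);                    // crude cross-thread combine for the demo
}

int main() {
  const int n = 1 << 20;
  float *buf, *out;
  cudaMallocManaged(&buf, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; i++) buf[i] = 1.0f;
  *out = 0.0f;
  strided_reduce<<<1, 256>>>(buf, n, out);
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n);
  cudaFree(buf);
  cudaFree(out);
  return 0;
}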