File tree (expand / collapse): 1 file changed, +18 -0 lines changed
aten/src/ATen/native/cuda (expand / collapse): 1 file changed, +18 -0 lines changed
Columns: Original file line number | Diff line number | Diff line change
@@ -831,6 +831,23 @@ struct ReduceOp {
831831 } else {
832832 index_t input_offset = threadIdx .y ;
833833 index_t step = blockDim .y ;
834+ #ifdef USE_ROCM // Prefetch loads to better hide their latency
835+ #define PRFCH 4
836+ for (; input_offset < config.ctas_per_output ; input_offset += step*PRFCH) {
837+ arg_vec_t next[PRFCH];
838+ #pragma unroll
839+ for (int u = 0 ; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output ); u++) {
840+ index_t idx = config.staging_memory_offset (input_offset + u*step);
841+ next[u] = reduce_buffer[idx];
842+ }
843+ for (int u = 0 ; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output ); u++) {
844+ #pragma unroll
845+ for (int i = 0 ; i < output_vec_size; i++) {
846+ value[i] = ops.combine (value[i], next[u][i]);
847+ }
848+ }
849+ }
850+ #else
834851 for (; input_offset < config.ctas_per_output ; input_offset += step) {
835852 index_t idx = config.staging_memory_offset (input_offset);
836853 arg_vec_t next = reduce_buffer[idx];
@@ -839,6 +856,7 @@ struct ReduceOp {
839856 value[i] = ops.combine (value[i], next[i]);
840857 }
841858 }
859+ #endif
842860 }
843861 value = block_y_reduce<output_vec_size>(value, shared_memory);
844862 if (config.should_block_x_reduce ()) {
You can’t perform that action at this time.
0 commit comments