File tree Expand file tree Collapse file tree 1 file changed +18
-0
lines changed
aten/src/ATen/native/cuda Expand file tree Collapse file tree 1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -822,6 +822,23 @@ struct ReduceOp {
822822 } else {
823823 index_t input_offset = threadIdx .y ;
824824 index_t step = blockDim .y ;
825+ #ifdef USE_ROCM // Prefetch up to PRFCH staged values per pass before combining them, to better hide load latency
826+ #define PRFCH 4
827+ for (; input_offset < config.ctas_per_output ; input_offset += step*PRFCH) {
828+ arg_vec_t next[PRFCH];
829+ #pragma unroll
830+ for (int u = 0 ; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output ); u++) {
831+ index_t idx = config.staging_memory_offset (input_offset + u*step);
832+ next[u] = reduce_buffer[idx];
833+ }
834+ for (int u = 0 ; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output ); u++) {
835+ #pragma unroll
836+ for (int i = 0 ; i < output_vec_size; i++) {
837+ value[i] = ops.combine (value[i], next[u][i]);
838+ }
839+ }
840+ }
841+ #else
825842 for (; input_offset < config.ctas_per_output ; input_offset += step) {
826843 index_t idx = config.staging_memory_offset (input_offset);
827844 arg_vec_t next = reduce_buffer[idx];
@@ -830,6 +847,7 @@ struct ReduceOp {
830847 value[i] = ops.combine (value[i], next[i]);
831848 }
832849 }
850+ #endif
833851 }
834852 value = block_y_reduce<output_vec_size>(value, shared_memory);
835853 if (config.should_block_x_reduce ()) {
You can’t perform that action at this time.
0 commit comments