
Commit 6ec3241

jerrymannil (AMD) authored and committed
[ROCm] Use fine-grain fence in reduction (#2553)
Cherry-pick of pytorch#160979. Less-performant fix until pytorch#161180 is finalized.

* The global reduction path in the reduction kernel currently has two threadfence operations.
* The first threadfence is executed by all threads in all blocks, whereas the second threadfence is only run by threads in a single block.
* For AMD GPUs, threadfence is a heavyweight operation, especially when run by all threads in the system (due to cross-XCD synchronizations).
* So using a fine-grain fence gives a significant performance boost on AMD GPUs.
* We do a release fence when threads write to the reduce buffer in global memory, and then an acquire fence when threads read from the reduce buffer.

Co-authors: @amd-hhashemi, @jeffdaily

**Reproducer**:
```python
import torch

shapes = [(2, 896, 59, 91), ]
dims = [(2, 3), ]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)
    # warmup
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
```
Fixes SWDEV-545710
1 parent fd4b1e7 commit 6ec3241

File tree: 1 file changed, 9 additions, 0 deletions


aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 9 additions & 0 deletions
```diff
@@ -797,14 +797,23 @@ struct ReduceOp {
     if (should_store) {
       index_t offset = config.staging_memory_offset(blockIdx.y);
       reduce_buffer[offset] = value;
+#ifdef USE_ROCM
+      __builtin_amdgcn_fence(__ATOMIC_RELEASE, "agent"); // make sure writes are globally visible
+#endif
     }

+#ifndef USE_ROCM
     __threadfence(); // make sure writes are globally visible
+#endif
     __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
     bool is_last_block_done = mark_block_finished();

     if (is_last_block_done) {
+#ifdef USE_ROCM
+      __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent"); // complete the acquire pattern after release
+#else
       __threadfence(); // complete the acquire pattern after atomic
+#endif
       for (auto &v : value) {
         v = ident;
       }
```