From 9bef61331d923028b7c563a776d89818ba013aa8 Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Fri, 22 Aug 2025 09:55:48 -0700
Subject: [PATCH] [ROCm] Use fine-grain fence in reduction (#2553)

Cherry-pick of https://github.com/pytorch/pytorch/pull/160979

A less-performant fix until https://github.com/pytorch/pytorch/pull/161180 is finalized.

* The global reduction path in the reduction kernel currently has two threadfence operations.
* The first threadfence is executed by all threads in all blocks, whereas the second threadfence is only run by threads in a single block.
* On AMD GPUs, threadfence is a heavyweight operation, especially when run by all threads in the system (due to cross-XCD synchronization).
* Using a fine-grained fence therefore gives a significant performance boost on AMD GPUs.
* We issue a release fence when threads write to the reduce buffer in global memory, and an acquire fence before threads read from the reduce buffer.

Co-authors: @amd-hhashemi, @jeffdaily

**Reproducer**:
```python
import time
import torch

shapes = [(2, 896, 59, 91), ]
dims = [(2, 3), ]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)

    # Warmup iterations before timing.
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
```

Fixes SWDEV-545710
---
 aten/src/ATen/native/cuda/Reduce.cuh | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 7cc71711d01d6..e67adff4a4d98 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -797,14 +797,23 @@ struct ReduceOp {
     if (should_store) {
       index_t offset = config.staging_memory_offset(blockIdx.y);
       reduce_buffer[offset] = value;
+#ifdef USE_ROCM
+      __builtin_amdgcn_fence(__ATOMIC_RELEASE, "agent"); // make sure writes are globally visible
+#endif
     }
+#ifndef USE_ROCM
     __threadfence(); // make sure writes are globally visible
+#endif
     __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
 
     bool is_last_block_done = mark_block_finished();
 
     if (is_last_block_done) {
+#ifdef USE_ROCM
+      __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent"); // complete the acquire pattern after release
+#else
       __threadfence(); // complete the acquire pattern after atomic
+#endif
       for (auto &v : value) {
         v = ident;
       }
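
Note for reviewers: the diff above relies on a release/acquire handoff around the staging buffer. Below is a minimal standalone sketch of that pattern, not the PyTorch kernel; the names `reduce_sum`, `partial_sums`, and `block_done_count`, and the fixed block size of 256, are hypothetical and chosen for brevity.
```cpp
// Two-phase global sum: each block stages a partial result, and the last
// block to finish combines them. On ROCm, GPU-scoped ("agent") fences
// replace the heavyweight __threadfence(), mirroring the patch above.
#define BLOCK 256

__device__ unsigned int block_done_count = 0; // must be zeroed between launches

__global__ void reduce_sum(const float* in, float* partial_sums, float* out, int n) {
  __shared__ float smem[BLOCK];
  __shared__ bool is_last_block_done;
  int tid = threadIdx.x;
  int idx = blockIdx.x * BLOCK + tid;

  // Per-block tree reduction into smem[0].
  smem[tid] = (idx < n) ? in[idx] : 0.0f;
  __syncthreads();
  for (int s = BLOCK / 2; s > 0; s >>= 1) {
    if (tid < s) smem[tid] += smem[tid + s];
    __syncthreads();
  }

  if (tid == 0) {
    partial_sums[blockIdx.x] = smem[0]; // stage this block's partial result
#ifdef USE_ROCM
    // Release fence: orders the staging write before the counter increment,
    // scoped to this GPU only.
    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "agent");
#else
    __threadfence(); // make the staging write globally visible
#endif
    unsigned int done = atomicAdd(&block_done_count, 1u);
    is_last_block_done = (done == gridDim.x - 1);
  }
  __syncthreads(); // broadcast is_last_block_done to the whole block

  if (is_last_block_done) {
#ifdef USE_ROCM
    // Acquire fence pairs with the release above, so every staged partial
    // is visible before this block reads them.
    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent");
#else
    __threadfence(); // complete the acquire pattern
#endif
    float total = 0.0f;
    for (int b = tid; b < gridDim.x; b += BLOCK) total += partial_sums[b];
    smem[tid] = total;
    __syncthreads();
    for (int s = BLOCK / 2; s > 0; s >>= 1) {
      if (tid < s) smem[tid] += smem[tid + s];
      __syncthreads();
    }
    if (tid == 0) *out = smem[0];
  }
}
```
The `"agent"` scope is what makes these fences cheaper: per the PR description, ordering is only enforced within the single GPU, avoiding the cross-XCD synchronization that the conservative lowering of `__threadfence` incurs on AMD hardware.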