
Commit 6e3be13

jerrymannil authored and pruthvistony committed
[ROCm] Use opportunistic fastatomics based on heuristics (pytorch#159430)
* Opportunistic fast atomics works better with small sizes, since there is more chance of lanes doing atomics on the same address

Co-author: @amd-hhashemi

Reproducer:
```
import time
import torch

x = torch.randn((1_632_960, 128), device='cuda', dtype=torch.float)
ind = torch.randint(0, x.size(0), size=(5_079_670,), device='cuda')
src = torch.randn((5_079_670, 128), device='cuda', dtype=torch.float)

for _ in range(20):
    x.index_add_(0, ind, src)

start_time = time.time()
for i in range(100):
    x.index_add_(0, ind, src)
torch.cuda.synchronize()
end_time = time.time()
mean_time = (end_time - start_time)/100
print(f"Avg time for index_add_: {mean_time * 1e6:.2f} us")
```

Perf numbers:
```
Before:
Avg time for index_add_: 25652.16 us

After:
Avg time for index_add_: 2675.15 us
```

Pull Request resolved: pytorch#159430
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily
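
The claim that small sizes give lanes "more chance" of hitting the same address can be illustrated with a rough estimate. The sketch below is not part of the commit: it assumes each lane in a 64-lane wave picks its destination uniformly and independently among `num_addresses` locations, which real index patterns need not follow. The function name `lane_share_probability` is hypothetical.

```python
def lane_share_probability(num_addresses: int, wave_size: int = 64) -> float:
    """Chance that one lane shares its destination with at least one other lane.

    Assumption (not from the commit): destinations are uniform and independent.
    """
    # Probability that all of the other (wave_size - 1) lanes miss this lane's address.
    all_others_miss = (1.0 - 1.0 / num_addresses) ** (wave_size - 1)
    return 1.0 - all_others_miss


if __name__ == "__main__":
    # 16 * 64 = 1024 is the threshold used in the diff.
    for n in (64, 16 * 64, 1_632_960 * 128):
        print(f"{n:>12} addresses -> share probability {lane_share_probability(n):.6f}")
```

Under this simplified model the chance of a lane finding a match drops quickly as the number of candidate addresses grows, which is consistent with skipping the lane-matching step for large tensors such as the reproducer's 1_632_960 x 128 destination.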
1 parent d533e56 commit 6e3be13

1 file changed: +8 -0 lines changed

aten/src/ATen/native/cuda/KernelUtils.cuh

Lines changed: 8 additions & 0 deletions
@@ -282,6 +282,14 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
   }

   // not coalsced, so now let try to capture lane-matches...
+
+  if (numel > 16 /*<-hueristic threshold*/ * 64 ) {
+    // well shucks, unlikely to capture same-dest atomics in a wave.
+    // fall back to direct fastAtomic...
+    fastAtomicAdd(self_ptr, index, numel, value, true);
+    return;
+  }
+
   // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd
   // __match_any_sync() -- returns bit mask of the threads that have same dest addr
   auto mask = __match_any_sync(__activemask(), (int64_t)dst);
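
The control flow added by this hunk can be modelled on the host as a minimal sketch. It is not the kernel code: `uses_direct_path` only mirrors the `numel > 16 * 64` check, and `merge_lanes` is a hypothetical stand-in for what the lane-matching path aims at, combining values headed to the same destination so that each unique address needs a single atomic add. All names below are illustrative.

```python
from collections import defaultdict

WAVE_SIZE = 64         # matches the `* 64` factor in the diff
THRESHOLD_FACTOR = 16  # the heuristic factor from the diff


def uses_direct_path(numel: int) -> bool:
    """Mirror of the new check: large tensors skip lane matching."""
    return numel > THRESHOLD_FACTOR * WAVE_SIZE


def merge_lanes(dests, values):
    """Sum the values of lanes that target the same destination.

    Illustrative model only: returns {destination: combined value}, i.e. one
    atomic add per unique destination instead of one per lane.
    """
    merged = defaultdict(float)
    for dest, value in zip(dests, values):
        merged[dest] += value
    return dict(merged)


if __name__ == "__main__":
    print(uses_direct_path(1024))   # at the threshold: lane matching is still tried
    print(uses_direct_path(1025))   # above the threshold: direct fastAtomicAdd
    print(merge_lanes([7, 7, 3, 7], [1.0, 2.0, 5.0, 4.0]))
```

In this model the benefit of merging depends entirely on how many lanes collide; when collisions are rare, the matching step is overhead, which is the case the early return is meant to avoid.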