Skip to content

Commit faae1f3

Browse files
authored
[release/2.7] [ROCm] Use opportunistic fastatomics based on heuristics (#2438)
* Merge of pytorch#159430 * Opportunistic fast atomics works better will small sizes, since there is more chance of lanes doing atomics on the same address Reproducer: ``` import time import torch x = torch.randn((1_632_960, 128), device='cuda', dtype=torch.float) ind = torch.randint(0, x.size(0), size=(5_079_670,), device='cuda') src = torch.randn((5_079_670, 128), device='cuda', dtype=torch.float) for _ in range(20): x.index_add_(0, ind, src) start_time = time.time() for i in range(100): x.index_add_(0, ind, src) torch.cuda.synchronize() end_time = time.time() mean_time = (end_time - start_time)/100 print(f"Avg time for index_add_: {mean_time * 1e6:.2f} us") ``` Perf numbers: ``` Before: Avg time for index_add_: 25652.16 us After: Avg time for index_add_: 2675.15 us ``` Co-author: @amd-hhashemi
1 parent 44c0e44 commit faae1f3

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

aten/src/ATen/native/cuda/KernelUtils.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
276276
}
277277
// not coalsced, so now let try to capture lane-matches...
278278

279+
if (numel > 16 /*<-hueristic threshold*/ * 64 ) {
280+
// well shucks, unlikely to capture same-dest atomics in a wave.
281+
// fall back to direct fastAtomic...
282+
fastAtomicAdd(self_ptr, index, numel, value, true);
283+
return;
284+
}
285+
279286
auto mask = __match_any_sync(__activemask(), (int64_t)dst);
280287
int leader = __ffsll(mask) - 1; // select a leader
281288
scalar_t crnt_val = (scalar_t)0;

0 commit comments

Comments
 (0)