[AMD][Backend] Optimize Global AtomicRMW Ops (#7496)
This follows the memory model semantics described in
triton-lang/triton#7292. Consider the following code:
```
@triton.jit
def atomic_kernel(x_ptr,
                  BLOCK_SIZE: tl.constexpr,
                  ):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    t1 = tl.full((BLOCK_SIZE, ), 1, dtype=tl.bfloat16)
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    tl.atomic_add(x_ptr + offsets, t1)
```
It emits a single tt.atomic_rmw op that is lowered to multiple
llvm.atomicrmw ops. For global (i.e. non-buffer) atomics, with the
top-level tt.atomic_rmw op set to acquire-release semantics (the default
when no sem value is explicitly set in the DSL-level tl.atomic_add):
```
tt.atomic_rmw fadd, acq_rel
```
is lowered to several llvm.atomicrmw ops (up to 8 for num_warps=1):
```
llvm.atomicrmw acq_rel
llvm.atomicrmw acq_rel
```
with every LLVM-level atomicrmw inheriting the acq_rel ordering. This
causes the LLVM AMDGCN backend to insert unnecessary buffer invalidates
and L2 cache write-backs around each atomic.
The buffer-atomic case is already handled well:
```
buffer_wbl2 sc1
buffer_atomic_pk_add_bf16 v1, v0, s[0:3], 0 offen
buffer_atomic_pk_add_bf16 v1, v0, s[0:3], 0 offen offset:4
buffer_inv sc1
```
However, for global atomics the following GCN is emitted:
```
buffer_wbl2 sc1
global_atomic_pk_add_bf16 v0, v1, s[0:1]
s_waitcnt vmcnt(0)
buffer_inv sc1
buffer_wbl2 sc1
global_atomic_pk_add_bf16 v0, v1, s[0:1] offset:4
s_waitcnt vmcnt(0)
buffer_inv sc1
```
causing a 2x-8x slowdown (depending on num_warps) compared to the
buffer-atomic case. After this PR, the emitted GCN no longer contains
the unnecessary buffer invalidates and L2 cache write-backs:
```
buffer_wbl2 sc1
global_atomic_pk_add_bf16 v0, v1, s[0:1]
global_atomic_pk_add_bf16 v0, v1, s[0:1] offset:4
s_waitcnt vmcnt(0)
buffer_inv sc1
```
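The fencing strategy above can be sketched as follows: a run of atomics that together carry acq_rel semantics needs only one release fence before the first op and one acquire fence after the last, with the individual atomics relaxed in between. A minimal Python sketch of that lowering decision (the function name and list-of-mnemonics representation are illustrative, not the actual compiler code):

```
def lower_atomic_sequence(num_atomics, ordering="acq_rel"):
    """Sketch: lower a run of atomics sharing one top-level ordering.

    Instead of wrapping every atomic in its own release/acquire fences,
    emit a single release fence before the run and a single acquire
    fence (preceded by a wait) after it; the atomics in between are
    relaxed. Mnemonics mirror the GCN shown in the PR description.
    """
    seq = []
    if ordering in ("acq_rel", "release"):
        seq.append("buffer_wbl2 sc1")  # release: one L2 write-back
    seq += ["global_atomic_pk_add_bf16"] * num_atomics  # relaxed atomics
    if ordering in ("acq_rel", "acquire"):
        seq.append("s_waitcnt vmcnt(0)")  # wait for all atomics to complete
        seq.append("buffer_inv sc1")      # acquire: one buffer invalidate
    return seq
```

For two atomics this yields the five-instruction sequence above rather than the eight-instruction per-op-fenced sequence.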
and global-atomic performance now matches the buffer-atomic case.
Co-authored w/ @karthik-man.