Commit df39911

[AMD][Backend] Optimize Global AtomicRMW Ops (#7496)
This follows the memory model semantics described in triton-lang/triton#7292.

The following code

```
@triton.jit
def atomic_kernel(x_ptr,
                  BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    t1 = tl.full((BLOCK_SIZE, ), 1, dtype=tl.bfloat16)
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    tl.atomic_add(x_ptr + offsets, t1)
```

emits a single tt.atomic_rmw op that is lowered to multiple llvm.atomicrmw ops. For global (i.e. non-buffer) atomics, with the top-level tt.atomic_rmw op set to acquire-release semantics (the default when no sem value is explicitly set on the DSL-level tl.atomic_add),

```
tt.atomic_rmw fadd, acq_rel
```

is lowered to (up to 8 ops for num_warps=1)

```
llvm.atomicrmw acq_rel
llvm.atomicrmw acq_rel
```

with every LLVM-level atomicrmw inheriting the acq_rel ordering. This causes the LLVM amdgcn backend to insert unnecessary buffer invalidates and L2 cache write-backs. The buffer-atomic case already handles this well:

```
buffer_wbl2 sc1
buffer_atomic_pk_add_bf16 v1, v0, s[0:3], 0 offen
buffer_atomic_pk_add_bf16 v1, v0, s[0:3], 0 offen offset:4
buffer_inv sc1
```

For global atomics, however, the following GCN is emitted:

```
buffer_wbl2 sc1
global_atomic_pk_add_bf16 v0, v1, s[0:1]
s_waitcnt vmcnt(0)
buffer_inv sc1
buffer_wbl2 sc1
global_atomic_pk_add_bf16 v0, v1, s[0:1] offset:4
s_waitcnt vmcnt(0)
buffer_inv sc1
```

causing a 2x-8x slowdown (depending on num_warps) compared to the buffer-atomic case.

After this PR the emitted GCN no longer contains the unnecessary buffer invalidates and L2 cache write-backs:

```
buffer_wbl2 sc1
global_atomic_pk_add_bf16 v0, v1, s[0:1]
global_atomic_pk_add_bf16 v0, v1, s[0:1] offset:4
s_waitcnt vmcnt(0)
buffer_inv sc1
```

and global-atomic and buffer-atomic performance now match.

Co-authored w/ @karthik-man.
1 parent a89c5bd commit df39911
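
The fix (see the LoadStoreOpToLLVM.cpp hunk below) keeps acquire-release semantics only at the boundaries of each thread's cluster of atomics and relaxes the interior ones. Here is a minimal sketch of that selection rule, written in Python purely for illustration; the helper name `pick_ordering` is hypothetical, while `i`, `vec`, and `elems_per_thread` mirror the variables in the C++ diff:

```python
def pick_ordering(i: int, vec: int, elems_per_thread: int) -> str:
    """Illustrative sketch of the ordering chosen per llvm.atomicrmw when a
    single acq_rel tt.atomic_rmw lowers to several atomics; `i` steps
    through a thread's elements in increments of `vec`."""
    if i == 0:
        return "release"    # first op: publish prior writes once
    if i == elems_per_thread - vec:
        return "acquire"    # last op: observe other agents' writes once
    return "monotonic"      # middle ops: relaxed, no extra fencing


# With 8 f32 elements per thread and a vector width of 1 (as in the fp32 lit
# test below) this yields release, six monotonics, then acquire.
print([pick_ordering(i, 1, 8) for i in range(8)])
```

Atomics outside an acq_rel cluster, or a thread with only a single atomic, keep their original ordering, which corresponds to the `vec > 1 || elemsPerThread > 1` guard in the patch.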

File tree

3 files changed: +98 -0 lines changed


test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 71 additions & 0 deletions
@@ -475,3 +475,74 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @atomic_kernel_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) release
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acquire
+    %cst = arith.constant dense<true> : tensor<1024xi1, #blocked>
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<1024xbf16, #blocked>
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.addptr %arg0, %1 : !tt.ptr<bf16>, i32
+    %4 = tt.splat %3 : !tt.ptr<bf16> -> tensor<1024x!tt.ptr<bf16>, #blocked>
+    %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr<bf16>, #blocked>, tensor<1024xi32, #blocked>
+    %6 = tt.atomic_rmw fadd, acq_rel, gpu, %5, %cst_0, %cst : (tensor<1024x!tt.ptr<bf16>, #blocked>, tensor<1024xbf16, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xbf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @atomic_kernel_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) release
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acquire
+    %cst = arith.constant dense<true> : tensor<1024xi1, #blocked>
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<1024xbf16, #blocked>
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.addptr %arg0, %1 : !tt.ptr<bf16>, i32
+    %4 = tt.splat %3 : !tt.ptr<bf16> -> tensor<1024x!tt.ptr<bf16>, #blocked>
+    %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr<bf16>, #blocked>, tensor<1024xi32, #blocked>
+    %6 = tt.atomic_rmw fadd, acq_rel, gpu, %5, %cst_0, %cst : (tensor<1024x!tt.ptr<bf16>, #blocked>, tensor<1024xbf16, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xbf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @atomic_kernel_fp32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) release
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acquire
+    %cst = arith.constant dense<true> : tensor<1024xi1, #blocked>
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked>
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
+    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    %7 = tt.atomic_rmw fadd, acq_rel, gpu, %6, %cst_0, %cst : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xf32, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.h

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ class AtomicRMWEmitter {
 
   Value emitPairedAtomicForEvenTID(RewriterBase &rewriter, Value rmwPtr,
                                    Value valElem, Value rmwMask) const;
+  void setAtomicOrdering(LLVM::AtomicOrdering memOrder) {
+    this->memOrder = memOrder;
+  }
 
 private:
   const mlir::triton::AMD::TargetInfo &targetInfo;

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 24 additions & 0 deletions
@@ -1529,6 +1529,29 @@ struct AtomicRMWOpConversion
         valElement = vecVal;
       }
 
+      // If we have a single tl.atomic_rmw that is lowered into multiple
+      // llvm.atomic_rmw ops and set the ordering for each to acq_rel (the
+      // default if no sem value is explicitly set in the DSL-level
+      // tl.atomic_add), the llvm backend will insert extra buffer
+      // invalidates and L2 write backs, causing a performance degradation.
+      // To avoid this we set the ordering to release for the first, acquire
+      // for the last, and relaxed for anything in between, so that only a
+      // single set of buffer_inv and buffer_wbl2 instructions is inserted
+      // by the backend for any "cluster" of atomic ops.
+      if ((vec > 1 || elemsPerThread > 1) &&
+          op.getSem() == MemSemantic::ACQUIRE_RELEASE) {
+        if (i == 0) {
+          // First
+          emitter.setAtomicOrdering(LLVM::AtomicOrdering::release);
+        } else if (i == elemsPerThread - vec) {
+          // Last
+          emitter.setAtomicOrdering(LLVM::AtomicOrdering::acquire);
+        } else {
+          // Middle
+          emitter.setAtomicOrdering(LLVM::AtomicOrdering::monotonic);
+        }
+      }
+
       Value retVal =
           emitter.emitAtomicRMW(rewriter, ptrElements[i], valElement, rmwMask,
                                 atomicSharedMemBase, enableIntraWaveReduce);
@@ -1548,6 +1571,7 @@ struct AtomicRMWOpConversion
       Value atomPtr = *atomicSharedMemBase;
       b.barrier();
       Value ret = b.load(valueElemTy, atomPtr);
+
       rewriter.replaceOp(op, {ret});
     }
   }
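
For completeness, a hedged end-to-end sketch of exercising this path from the DSL, assuming a ROCm build of PyTorch and Triton; the tensor size and launch configuration are illustrative and the kernel mirrors the one in the commit message:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def atomic_kernel(x_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    t1 = tl.full((BLOCK_SIZE, ), 1, dtype=tl.bfloat16)
    # No explicit `sem` argument, so the tt.atomic_rmw defaults to acq_rel;
    # with this change the AMD backend fences each thread's cluster of global
    # atomics once instead of around every llvm.atomicrmw.
    tl.atomic_add(x_ptr + offsets, t1)


x = torch.zeros(1024, dtype=torch.bfloat16, device="cuda")  # "cuda" maps to ROCm in PyTorch
atomic_kernel[(x.numel() // 1024,)](x, BLOCK_SIZE=1024)
```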
