
Commit f7e6775

[AMD] Pass down atomics memscope through lowering (triton-lang#5580)
# Overview

Atomics in Triton have two optional attributes:

1) `sem` -- the memory semantics of the operation
2) `scope` -- which threads will see the effect of a memory operation (e.g., GPU, CTA)

Presently, the `scope` is ignored by the AMD backend and defaults to `agent` scope in the emitted LLVM (which roughly corresponds to the `gpu` memscope in Triton). This is correct (in most cases, perhaps not all), as `agent` is a "stricter" scope than CTA, and system scope is presumably rare in AMD kernels, so no bugs have shown up. That said, emitting atomics at CTA scope can be more efficient, since fewer cache invalidations/barriers are needed. This is fixable by simply passing the attribute through to the generated `llvm.atomicrmw` op.

Some additional optimizations are potentially possible (e.g., `!amdgpu.no.remote.memory`, since Triton doesn't support remote memory today), but it isn't clear that they would have any real end-to-end performance impact, and they would be specific to the `sys` scope, which doesn't appear to be frequently used.

# Testing

I added a lit test to ensure that the generated LLVM instructions have the correct sem/scope attributes for `atomicrmw`, and also ran the following 386 unit tests locally on an MI300x:

```bash
pytest test/unit/language/test_core.py -k test_atomic_
```

I then locally ran some kernels with the scope set to CTA/SYSTEM to make sure that they worked.
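For context, here is a minimal sketch of where these attributes originate at the language level (the kernel and its parameters are hypothetical, and assume a Triton version whose atomic ops accept `sem` and `scope` keyword arguments):

```python
import triton
import triton.language as tl

# Hypothetical example: scope="cta" is the case this change makes cheaper on
# AMD, lowering to syncscope("workgroup") instead of the blanket "agent".
@triton.jit
def atomic_add_kernel(ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # sem="relaxed" -> `monotonic` ordering; scope="cta" -> syncscope("workgroup")
    tl.atomic_add(ptr + offs, 1.0, mask=mask, sem="relaxed", scope="cta")
```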
1 parent a3095b3 commit f7e6775

2 files changed (+81, −6 lines)

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 43 additions & 0 deletions
```diff
@@ -209,3 +209,46 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomicrmw_scope_memsemantics
+  tt.func @atomicrmw_scope_memsemantics(%arg0 : tensor<128x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<128xi1, #blocked0>, %arg2 : tensor<128xf32, #blocked0>) {
+    // relaxed
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} monotonic
+    %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic
+    %1 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) monotonic
+    %2 = tt.atomic_rmw fadd, relaxed, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+
+    // acquire
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} acquire
+    %3 = tt.atomic_rmw fadd, acquire, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acquire
+    %4 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) acquire
+    %5 = tt.atomic_rmw fadd, acquire, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+
+    // release
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} release
+    %6 = tt.atomic_rmw fadd, release, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) release
+    %7 = tt.atomic_rmw fadd, release, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) release
+    %8 = tt.atomic_rmw fadd, release, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+
+    // acq_rel
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} acq_rel
+    %9 = tt.atomic_rmw fadd, acq_rel, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acq_rel
+    %10 = tt.atomic_rmw fadd, acq_rel, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+    // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) acq_rel
+    %11 = tt.atomic_rmw fadd, acq_rel, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0>
+
+    tt.return
+  }
+}
```
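Note that for `sys` scope the CHECK lines expect no `syncscope` attribute at all: the op falls back to LLVM's default sync scope, which is "system" for AMDHSA (see the comment in `getAMDGPUMemScopeStr` in the lowering change below).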

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 38 additions & 6 deletions
```diff
@@ -227,6 +227,29 @@ struct LoadStoreConversionBase {
     return axisAnalysisPass.getPtrAlignment(ptr);
   }
 
+  std::optional<const std::string>
+  getAMDGPUMemScopeStr(MemSyncScope scope) const {
+    // See: https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
+    auto scopeStr = "";
+    switch (scope) {
+    case MemSyncScope::SYSTEM:
+      // The default AMDHSA LLVM Sync Scope is "system", so no string is
+      // provided here
+      scopeStr = "";
+      break;
+    case MemSyncScope::GPU:
+      scopeStr = "agent";
+      break;
+    case MemSyncScope::CTA:
+      scopeStr = "workgroup";
+      break;
+    default:
+      return std::nullopt;
+    }
+
+    return scopeStr;
+  }
+
 protected:
   const AMD::TargetInfo &targetInfo;
   ModuleAxisInfoAnalysis &axisAnalysisPass;
```
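For quick reference, the mapping implemented above is:

| Triton `MemSyncScope` | LLVM syncscope string |
| --- | --- |
| `SYSTEM` | *(empty: AMDHSA default, i.e., "system")* |
| `GPU` | `"agent"` |
| `CTA` | `"workgroup"` |
| anything else | `std::nullopt` (the callers below bail out with `failure()`) |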
```diff
@@ -601,6 +624,10 @@ struct AtomicCASOpConversion
 
     auto memOrdering = op.getSem();
     auto atomicMemOrdering = getMemoryOrdering(memOrdering);
+    auto scope = op.getScope();
+    auto scopeStr = getAMDGPUMemScopeStr(scope);
+    if (!scopeStr)
+      return failure();
 
     // deal with tensor or scalar
     auto valueTy = op.getResult().getType();
@@ -643,7 +670,7 @@
       auto failureOrdering = LLVM::AtomicOrdering::monotonic;
       auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
           loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
-          StringRef("agent"));
+          StringRef(scopeStr.value()));
 
       // Extract the new_loaded value from the pair.
       Value ret = extract_val(valueElemTy, cmpxchg, i);
@@ -852,8 +879,13 @@ struct AtomicRMWOpConversion
       mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0)));
 
     auto memOrdering = op.getSem();
+    auto scope = op.getScope();
     auto atomicMemOrdering = getMemoryOrdering(memOrdering);
 
+    auto scopeStr = getAMDGPUMemScopeStr(scope);
+    if (!scopeStr)
+      return failure();
+
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
     retType = useDppForPackedF16 ? packF16Ty : retType;
@@ -907,11 +939,11 @@
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
       // atomics for MI-* series of AMD GPU.
-      Value atom =
-          rewriter
-              .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
-                                         atomicMemOrdering, StringRef("agent"))
-              .getResult();
+      Value atom = rewriter
+                       .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr,
+                                                  operand, atomicMemOrdering,
+                                                  StringRef(scopeStr.value()))
+                       .getResult();
       if (!tensorTy) {
         if (atomicNeedsSharedMemory(op.getResult())) {
           Value atomPtr =
```
