
Commit 29912c0

[BACKEND] Promote tl.atomic_add to PTX ld.acquire when possible (intel#5187)
To optimize the scalar case `tl.atomic_add(ptr, 0)`, there is a new path that lowers it to PTX `ld.acquire` with a scope qualifier (`.cta`, `.gpu`, or `.sys`). The TTGIR `AtomicRMW` lowering emits `nvgpu.ld_acquire`, and the subsequent `NVGPU::LoadAcquireOp` lowering emits an LLVM inline-PTX `ld.acquire`. The purpose is to generate better code for synchronizing groups of threads during a cooperative thread launch.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests
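For illustration, here is a minimal sketch (not part of this change) of the Triton-level pattern the new path targets: a scalar atomic add of zero used as an acquire load, e.g. to observe a flag published by another block during a cooperative launch. The kernel and pointer names are hypothetical, and it assumes the `sem`/`scope` keyword arguments of `tl.atomic_add`.

```python
import triton
import triton.language as tl


@triton.jit
def read_flag_kernel(flag_ptr, out_ptr):
    # Scalar atomic add of zero: semantically just a read of *flag_ptr.
    # With this change, the backend can emit PTX `ld.acquire.gpu` for it
    # instead of a full atomic RMW.
    flag = tl.atomic_add(flag_ptr, 0, sem="acquire", scope="gpu")
    tl.store(out_ptr, flag)
```

The acquire semantics order the threads' subsequent memory reads after the observation of the flag, which is why a plain (non-atomic) load is not a substitute here.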
1 parent ad16e3d commit 29912c0

4 files changed: +192 -1 lines changed


test/Conversion/atomic_ldst.mlir

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s --check-prefix=CHECK-TTG2NVGPU
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 --convert-nv-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=CHECK-NVGPU2LLVM
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @kernel_r(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %true = arith.constant true
+    %c128_i32 = arith.constant 128 : i32
+    %c512_i32 = arith.constant 512 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c128_i32 : i32
+    %2 = arith.cmpi slt, %1, %c512_i32 : i32
+
+    // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, gpu
+    // CHECK-NVGPU2LLVM: ld.global.gpu.acquire.b32
+    %3 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %cst, %2 : (!tt.ptr<f32>, f32, i1) -> f32
+    tt.store %arg0, %3 : !tt.ptr<f32>
+
+    // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, cta
+    // CHECK-NVGPU2LLVM: ld.global.cta.acquire.b32
+    %4 = tt.atomic_rmw fadd, acquire, cta, %arg0, %cst, %true : (!tt.ptr<f32>, f32, i1) -> f32
+    tt.store %arg0, %4 : !tt.ptr<f32>
+
+    // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, sys
+    // CHECK-NVGPU2LLVM: ld.global.sys.acquire.b32
+    %5 = tt.atomic_rmw fadd, acquire, sys, %arg0, %cst, %2 : (!tt.ptr<f32>, f32, i1) -> f32
+    tt.store %arg0, %5 : !tt.ptr<f32>
+    tt.return
+  }
+}

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 38 additions & 0 deletions
@@ -32,6 +32,33 @@ include "NVGPUAttrDefs.td"
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
 def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
 
+
+def NVGPU_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def NVGPU_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
+def NVGPU_ScalarLike : AnyTypeOf<[NVGPU_Float, NVGPU_Int]>;
+
+
+def NVGPU_MemSemanticAttr : I32EnumAttr<
+    "MemSemantic", "",
+    [
+      I32EnumAttrCase<"RELAXED", 1, "relaxed">,
+      I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
+      I32EnumAttrCase<"RELEASE", 3, "release">,
+      I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
+    ]> {
+  let cppNamespace = "::mlir::triton::nvgpu";
+}
+
+def NVGPU_MemSyncScopeAttr : I32EnumAttr<
+    "MemSyncScope", "",
+    [
+      I32EnumAttrCase<"GPU", 1, "gpu">,
+      I32EnumAttrCase<"CTA", 2, "cta">,
+      I32EnumAttrCase<"SYSTEM", 3, "sys">,
+    ]> {
+  let cppNamespace = "::mlir::triton::nvgpu";
+}
+
 class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
     LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;

@@ -123,4 +150,15 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
   let assemblyFormat = "attr-dict";
 }
 
+def NVGPU_LoadAcquireOp : NVGPU_Op<"ld_acquire", [MemoryEffects<[MemRead]>]> {
+  let arguments = (
+    ins LLVM_PointerGlobal:$addr,
+    Optional<I1>:$mask,
+    NVGPU_MemSemanticAttr:$sem,
+    NVGPU_MemSyncScopeAttr:$scope
+  );
+  let results = (outs NVGPU_ScalarLike:$result);
+  let assemblyFormat = "$sem `,` $scope `,` $addr (`,` $mask^)? attr-dict `:` functional-type($addr, $result)";
+}
+
 #endif

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 41 additions & 1 deletion
@@ -370,6 +370,46 @@ class LoadMatrixOpPattern
   }
 };
 
+class LoadAcquireOpPattern : public OpRewritePattern<ttn::LoadAcquireOp> {
+public:
+  using OpRewritePattern<ttn::LoadAcquireOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttn::LoadAcquireOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    Type valueTy = op.getType();
+    const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth());
+    const size_t maxWordWidth = std::max<size_t>(32, valueNBits);
+    const size_t width = std::min((size_t)valueNBits, maxWordWidth);
+
+    const std::string writeConstraint =
+        (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
+    PTXBuilder ptxBuilder;
+    bool init = true;
+    auto *dstOpr = ptxBuilder.newOperand(writeConstraint, init); // =r operation
+    auto *addrOpr =
+        ptxBuilder.newAddrOperand(op.getAddr(), "l", 0 /* in_off */);
+    auto &ld =
+        ptxBuilder.create<>("ld")
+            ->global()
+            .o("cta", op.getScope() == triton::nvgpu::MemSyncScope::CTA)
+            .o("gpu", op.getScope() == triton::nvgpu::MemSyncScope::GPU)
+            .o("sys", op.getScope() == triton::nvgpu::MemSyncScope::SYSTEM)
+            .o("acquire", op.getSem() == triton::nvgpu::MemSemantic::ACQUIRE)
+            .o("relaxed", op.getSem() == triton::nvgpu::MemSemantic::RELAXED)
+            .b(width);
+    ld(dstOpr, addrOpr).maybePredicate(op.getMask(), "b");
+
+    // Create inline ASM signature
+    Type retTy = IntegerType::get(getContext(), width);
+    Value ret = ptxBuilder.launch(rewriter, loc, retTy);
+    ret = bitcast(ret, op.getType());
+
+    rewriter.replaceOp(op, {ret});
+    return success();
+  }
+};
+
 class WGMMAWaitGroupOpPattern : public OpRewritePattern<ttn::WGMMAWaitGroupOp> {
 public:
   using OpRewritePattern<ttn::WGMMAWaitGroupOp>::OpRewritePattern;

@@ -608,7 +648,7 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
 
     patterns.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
                  StoreMatrixOpPattern, ClusterArriveOpPattern, WGMMAOpPattern,
-                 WGMMAWaitGroupOpPattern>(context);
+                 LoadAcquireOpPattern, WGMMAWaitGroupOpPattern>(context);
 
     if (applyPatternsGreedily(mod, std::move(patterns)).failed())
       signalPassFailure();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 84 additions & 0 deletions
@@ -1,3 +1,4 @@
+#include "Dialect/NVGPU/IR/Dialect.h"
 #include "TargetInfo.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"

@@ -24,6 +25,9 @@ using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
+// Toggle this to work around Cooperative Grid Launch ld.acquire optimized path
+static constexpr bool disableLDAcquireLowering = false;
+
 namespace {
 
 llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx) {

@@ -696,6 +700,48 @@ struct AtomicRMWOpConversion
            (elementType.isF16() || elementType.isBF16() || elementType.isF32());
   }
 
+  bool isPromotableToNVPTXLD(triton::AtomicRMWOp op) const {
+    if (disableLDAcquireLowering)
+      return false;
+
+    Type valueTy =
+        getTypeConverter()->convertType(getElementTypeOrSelf(op.getType()));
+
+    if (!valueTy.isIntOrFloat())
+      return false;
+    if (op.getSem() != triton::MemSemantic::ACQUIRE &&
+        op.getSem() != triton::MemSemantic::RELAXED)
+      return false;
+    if (op.getScope() != triton::MemSyncScope::CTA &&
+        op.getScope() != triton::MemSyncScope::GPU &&
+        op.getScope() != triton::MemSyncScope::SYSTEM)
+      return false;
+
+    if (op.getAtomicRmwOp() != RMWOp::ADD && op.getAtomicRmwOp() != RMWOp::FADD)
+      return false;
+    if (isa<RankedTensorType>(op.getType()))
+      return false;
+    if (!op.getVal().getDefiningOp())
+      return false;
+    if (!isa<arith::ConstantOp>(op.getVal().getDefiningOp()))
+      return false;
+
+    auto constOp = cast<arith::ConstantOp>(op.getVal().getDefiningOp());
+    if (!isa<FloatAttr>(constOp.getValueAttr()) &&
+        !isa<IntegerAttr>(constOp.getValueAttr()))
+      return false;
+
+    if (auto attr = dyn_cast_or_null<FloatAttr>(constOp.getValueAttr()))
+      if (!attr.getValue().isZero())
+        return false;
+
+    if (auto attr = dyn_cast_or_null<IntegerAttr>(constOp.getValueAttr()))
+      if (!attr.getValue().isZero())
+        return false;
+
+    return true;
+  }
+
   LogicalResult
   matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {

@@ -767,6 +813,17 @@ struct AtomicRMWOpConversion
 
     auto packedTy = vec_ty(valueElemTy, packed);
     SmallVector<Value> resultVals(elemsPerThread);
+
+    // Lower AtomicRMWOp to a ld.acquire if possible
+    std::unordered_map<triton::MemSyncScope, triton::nvgpu::MemSyncScope>
+        ScopeMap = {
+            {triton::MemSyncScope::CTA, triton::nvgpu::MemSyncScope::CTA},
+            {triton::MemSyncScope::GPU, triton::nvgpu::MemSyncScope::GPU},
+            {triton::MemSyncScope::SYSTEM,
+             triton::nvgpu::MemSyncScope::SYSTEM}};
+    const bool doPTXLDPromotion = isPromotableToNVPTXLD(op) && vec == 1 &&
+                                  packed == 1 && ScopeMap.count(op.getScope());
+
     for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
       if (auto canonicalStart = getCanonicalIndex(i, regMask);
           canonicalStart != i) {

@@ -780,6 +837,33 @@ struct AtomicRMWOpConversion
       Value rmwPtr = ptrElements[i];
       Value pred = llMask ? maybeAnd(rewriter, loc, threadPred, maskElements[i])
                           : threadPred;
+
+      if (doPTXLDPromotion) {
+        Type covertedValueTy =
+            getTypeConverter()->convertType(getElementTypeOrSelf(op.getType()));
+        auto loadAcquireOp = rewriter.create<triton::nvgpu::LoadAcquireOp>(
+            op.getLoc(), covertedValueTy, rmwPtr, pred,
+            op.getSem() == triton::MemSemantic::ACQUIRE
+                ? triton::nvgpu::MemSemantic::ACQUIRE
+                : triton::nvgpu::MemSemantic::RELAXED,
+            ScopeMap[op.getScope()]);
+
+        auto ASMReturnTy = void_ty(ctx);
+        if (!atomicNeedsSharedMemory(op.getResult())) {
+          rewriter.eraseOp(op);
+          return success();
+        }
+        Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
+                                                  op.getOperation());
+        atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
+        // Only threads with rmwMask = True store the result
+        targetInfo.storeShared(rewriter, loc, atomPtr, loadAcquireOp, pred);
+        createBarrier(rewriter, loc, numCTAs);
+        Value ret = load(valueElemTy, atomPtr);
+        rewriter.replaceOp(op, {ret});
+        continue;
+      }
+
       std::string sTy;
       PTXBuilder ptxBuilderAtomicRMW;
       // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"