Skip to content

Commit 2ea8e70

Browse files
Hardcode84mahesh-attarde
authored andcommitted
[mlir][amdgpu] Add rocdl.s.waitcnt wrapper (llvm#149670)
The main motivations is to pass vmcnt/expcnt/lgkmcnt values directly (similar to the asm format) and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering. --------- Signed-off-by: Ivan Butygin <[email protected]>
1 parent 247adce commit 2ea8e70

File tree

4 files changed

+191
-3
lines changed

4 files changed

+191
-3
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,29 @@ def AMDGPU_SchedBarrierOp :
719719
}];
720720
}
721721

722+
def AMDGPU_MemoryCounterWaitOp :
723+
AMDGPU_Op<"memory_counter_wait">,
724+
Arguments<(ins
725+
OptionalAttr<I32Attr>:$load,
726+
OptionalAttr<I32Attr>:$store,
727+
OptionalAttr<I32Attr>:$ds,
728+
OptionalAttr<I32Attr>:$exp
729+
)>
730+
{
731+
let summary = "Wait for specified hardware counters";
732+
let description = [{
733+
Wait for the specified counters to be less-than or equal-to the provided
734+
values before continuing.
735+
736+
Counters can lower to different instructions on different architectires,
737+
including clamping to the some HW supported max value or combining multiple
738+
counters into one.
739+
}];
740+
let assemblyFormat = [{
741+
oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict
742+
}];
743+
}
744+
722745
def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
723746
"The possible permutations of the lanes storing B available in an MFMA",
724747
[

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
419419
}
420420
};
421421

422+
// TODO: AMDGPU backend already have all this bitpacking logic, we should move
423+
// it to some common place.
424+
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
425+
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
426+
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
427+
/// Vmcnt = Waitcnt[15:10] (gfx11)
428+
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
429+
/// Expcnt = Waitcnt[2:0] (gfx11)
430+
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
431+
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
432+
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
433+
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
434+
unsigned expcnt, unsigned lgkmcnt) {
435+
if (chipset.majorVersion < 9) {
436+
vmcnt = std::min(15u, vmcnt);
437+
expcnt = std::min(7u, expcnt);
438+
lgkmcnt = std::min(15u, lgkmcnt);
439+
return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
440+
}
441+
if (chipset.majorVersion == 9) {
442+
vmcnt = std::min(63u, vmcnt);
443+
expcnt = std::min(7u, expcnt);
444+
lgkmcnt = std::min(15u, lgkmcnt);
445+
unsigned lowBits = vmcnt & 0xF;
446+
unsigned highBits = (vmcnt >> 4) << 14;
447+
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
448+
return lowBits | highBits | otherCnts;
449+
}
450+
if (chipset.majorVersion == 10) {
451+
vmcnt = std::min(63u, vmcnt);
452+
expcnt = std::min(7u, expcnt);
453+
lgkmcnt = std::min(63u, lgkmcnt);
454+
unsigned lowBits = vmcnt & 0xF;
455+
unsigned highBits = (vmcnt >> 4) << 14;
456+
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
457+
return lowBits | highBits | otherCnts;
458+
}
459+
if (chipset.majorVersion == 11) {
460+
vmcnt = std::min(63u, vmcnt);
461+
expcnt = std::min(7u, expcnt);
462+
lgkmcnt = std::min(63u, lgkmcnt);
463+
return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
464+
}
465+
return failure();
466+
}
467+
468+
struct MemoryCounterWaitOpLowering
469+
: public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
470+
MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
471+
Chipset chipset)
472+
: ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
473+
chipset(chipset) {}
474+
475+
Chipset chipset;
476+
477+
LogicalResult
478+
matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
479+
ConversionPatternRewriter &rewriter) const override {
480+
if (chipset.majorVersion >= 12) {
481+
Location loc = op.getLoc();
482+
if (std::optional<int> ds = adaptor.getDs())
483+
rewriter.create<ROCDL::WaitDscntOp>(loc, *ds);
484+
485+
if (std::optional<int> load = adaptor.getLoad())
486+
rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load);
487+
488+
if (std::optional<int> store = adaptor.getStore())
489+
rewriter.create<ROCDL::WaitStorecntOp>(loc, *store);
490+
491+
if (std::optional<int> exp = adaptor.getExp())
492+
rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp);
493+
494+
rewriter.eraseOp(op);
495+
return success();
496+
}
497+
498+
auto getVal = [](Attribute attr) -> unsigned {
499+
if (attr)
500+
return cast<IntegerAttr>(attr).getInt();
501+
502+
// This value will be clamped to the maximum value for the chipset.
503+
return 1024;
504+
};
505+
unsigned ds = getVal(adaptor.getDsAttr());
506+
unsigned exp = getVal(adaptor.getExpAttr());
507+
508+
unsigned vmcnt = 1024;
509+
Attribute load = adaptor.getLoadAttr();
510+
Attribute store = adaptor.getStoreAttr();
511+
if (load && store) {
512+
vmcnt = getVal(load) + getVal(store);
513+
} else if (load) {
514+
vmcnt = getVal(load);
515+
} else if (store) {
516+
vmcnt = getVal(store);
517+
}
518+
519+
FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
520+
if (failed(waitcnt))
521+
return op.emitOpError("unsupported chipset");
522+
523+
rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
524+
return success();
525+
}
526+
};
527+
422528
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
423529
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
424530
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1931,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
18251931
ROCDL::RawPtrBufferAtomicUminOp>,
18261932
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
18271933
ROCDL::RawPtrBufferAtomicCmpSwap>,
1828-
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1829-
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1830-
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1934+
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
1935+
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
1936+
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
18311937
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
18321938
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
18331939
TransposeLoadOpLowering>(converter, chipset);
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
2+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10
3+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11
4+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12
5+
6+
// CHECK-LABEL: func @memory_counter_wait
7+
func.func @memory_counter_wait() {
8+
// GFX9: rocdl.s.waitcnt 53119
9+
// GFX10: rocdl.s.waitcnt 65407
10+
// GFX11: rocdl.s.waitcnt 65527
11+
// GFX12-NOT: rocdl.s.wait.loadcnt
12+
// GFX12-NOT: rocdl.s.wait.storecnt
13+
// GFX12-NOT: rocdl.s.wait.expcnt
14+
// GFX12-NOT: rocdl.s.wait.dscnt
15+
amdgpu.memory_counter_wait
16+
17+
// GFX9: rocdl.s.waitcnt 3952
18+
// GFX10: rocdl.s.waitcnt 16240
19+
// GFX11: rocdl.s.waitcnt 1015
20+
// GFX12: rocdl.s.wait.loadcnt 0
21+
amdgpu.memory_counter_wait load(0)
22+
23+
// GFX9: rocdl.s.waitcnt 3952
24+
// GFX10: rocdl.s.waitcnt 16240
25+
// GFX11: rocdl.s.waitcnt 1015
26+
// GFX12: rocdl.s.wait.storecnt 0
27+
amdgpu.memory_counter_wait store(0)
28+
29+
// GFX9: rocdl.s.waitcnt 53007
30+
// GFX10: rocdl.s.waitcnt 65295
31+
// GFX11: rocdl.s.waitcnt 65520
32+
// GFX12: rocdl.s.wait.expcnt 0
33+
amdgpu.memory_counter_wait exp(0)
34+
35+
// GFX9: rocdl.s.waitcnt 49279
36+
// GFX10: rocdl.s.waitcnt 49279
37+
// GFX11: rocdl.s.waitcnt 64519
38+
// GFX12: rocdl.s.wait.dscnt 0
39+
amdgpu.memory_counter_wait ds(0)
40+
41+
return
42+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,20 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
548548
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
549549
func.return
550550
}
551+
552+
// CHECK-LABEL: func @memory_counter_wait
553+
func.func @memory_counter_wait() {
554+
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
555+
// CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1)
556+
// CHECK: amdgpu.memory_counter_wait load(1)
557+
// CHECK: amdgpu.memory_counter_wait store(2)
558+
// CHECK: amdgpu.memory_counter_wait ds(3)
559+
// CHECK: amdgpu.memory_counter_wait exp(4)
560+
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
561+
amdgpu.memory_counter_wait exp(1) store(2) ds(3) load(4)
562+
amdgpu.memory_counter_wait load(1)
563+
amdgpu.memory_counter_wait store(2)
564+
amdgpu.memory_counter_wait ds(3)
565+
amdgpu.memory_counter_wait exp(4)
566+
func.return
567+
}

0 commit comments

Comments
 (0)