Skip to content

Commit f3bc55c

Browse files
committed
[mlir][amdgpu] Add amdgpu.waitcnt wrapper
The main motivations is to pass vmcnt/expcnt/lgkmcnt values directly and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering. Only gfx9 bitpacking support added as part of this commit.
1 parent 6db9b0d commit f3bc55c

File tree

4 files changed

+102
-3
lines changed

4 files changed

+102
-3
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,26 @@ def AMDGPU_SchedBarrierOp :
719719
}];
720720
}
721721

722+
def AMDGPU_WaitcntOp :
723+
AMDGPU_Op<"waitcnt">,
724+
Arguments<(ins
725+
OptionalAttr<I32Attr>:$vmcnt,
726+
OptionalAttr<I32Attr>:$expcnt,
727+
OptionalAttr<I32Attr>:$lgkmcnt
728+
)>
729+
{
730+
let summary = "Wrapper on ROCDL SWaitcntOp";
731+
let description = [{
732+
Covenience wrapper on `rocdl.s.waitcnt`. Hides the architecture specific
733+
bitpacking from user. Missing values will be assumed maximum values supported
734+
by the architecture. Large values will also be clamped to the maximum
735+
supported values.
736+
}];
737+
let assemblyFormat = [{
738+
(`vmcnt` `(` $vmcnt^ `)` )? (`expcnt` `(` $expcnt^ `)` )? (`lgkmcnt` `(` $lgkmcnt^ `)`)? attr-dict
739+
}];
740+
}
741+
722742
def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
723743
"The possible permutations of the lanes storing B available in an MFMA",
724744
[

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,52 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
419419
}
420420
};
421421

422+
// TODO: AMDGPU backend already have all this bitpacking logic, we should move
423+
// it to some common place.
424+
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
425+
unsigned expcnt, unsigned lgkmcnt) {
426+
if (chipset.majorVersion == 9) {
427+
vmcnt = std::min(63u, vmcnt);
428+
expcnt = std::min(7u, expcnt);
429+
lgkmcnt = std::min(15u, lgkmcnt);
430+
unsigned lowBits = vmcnt & 0xF;
431+
unsigned highBits = (vmcnt >> 4) << 14;
432+
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
433+
return lowBits | highBits | otherCnts;
434+
}
435+
return failure();
436+
}
437+
438+
struct WaitcntOpLowering : public ConvertOpToLLVMPattern<WaitcntOp> {
439+
WaitcntOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
440+
: ConvertOpToLLVMPattern<WaitcntOp>(converter), chipset(chipset) {}
441+
442+
Chipset chipset;
443+
444+
LogicalResult
445+
matchAndRewrite(WaitcntOp op, OpAdaptor adaptor,
446+
ConversionPatternRewriter &rewriter) const override {
447+
auto getVal = [](Attribute attr) -> unsigned {
448+
if (attr)
449+
return cast<IntegerAttr>(attr).getInt();
450+
451+
// This value will be clamped to the maximum value for the chipset.
452+
return 1024 * 1024;
453+
};
454+
unsigned vmcnt = getVal(adaptor.getVmcntAttr());
455+
unsigned expcnt = getVal(adaptor.getExpcntAttr());
456+
unsigned lgkmcnt = getVal(adaptor.getLgkmcntAttr());
457+
458+
FailureOr<unsigned> waitcnt =
459+
encodeWaitcnt(chipset, vmcnt, expcnt, lgkmcnt);
460+
if (failed(waitcnt))
461+
return op.emitOpError("unsupported chipset");
462+
463+
rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
464+
return success();
465+
}
466+
};
467+
422468
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
423469
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
424470
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1871,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
18251871
ROCDL::RawPtrBufferAtomicUminOp>,
18261872
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
18271873
ROCDL::RawPtrBufferAtomicCmpSwap>,
1828-
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1829-
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1830-
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1874+
AMDGPUDPPLowering, WaitcntOpLowering, LDSBarrierOpLowering,
1875+
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
1876+
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
18311877
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
18321878
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
18331879
TransposeLoadOpLowering>(converter, chipset);
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
2+
// TODO: Add more chipsets support
3+
4+
5+
// CHECK-LABEL: func @waitcnt
6+
func.func @waitcnt() {
7+
// GFX9: rocdl.s.waitcnt 53119
8+
amdgpu.waitcnt
9+
10+
// GFX9: rocdl.s.waitcnt 3952
11+
amdgpu.waitcnt vmcnt(0)
12+
13+
// GFX9: rocdl.s.waitcnt 53007
14+
amdgpu.waitcnt expcnt(0)
15+
16+
// GFX9: rocdl.s.waitcnt 49279
17+
amdgpu.waitcnt lgkmcnt(0)
18+
19+
return
20+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,16 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
548548
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
549549
func.return
550550
}
551+
552+
// CHECK-LABEL: func @waitcnt
553+
func.func @waitcnt() {
554+
// CHECK: amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3)
555+
// CHECK: amdgpu.waitcnt vmcnt(1)
556+
// CHECK: amdgpu.waitcnt expcnt(2)
557+
// CHECK: amdgpu.waitcnt lgkmcnt(3)
558+
amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3)
559+
amdgpu.waitcnt vmcnt(1)
560+
amdgpu.waitcnt expcnt(2)
561+
amdgpu.waitcnt lgkmcnt(3)
562+
func.return
563+
}

0 commit comments

Comments
 (0)