Skip to content

Commit c43bc26

Browse files
Defining amdgpu.scaled_mfma op
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent af223bc commit c43bc26

File tree

4 files changed

+198
-3
lines changed

4 files changed

+198
-3
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,4 +830,52 @@ def AMDGPU_GatherToLDSOp :
830830
let hasVerifier = 1;
831831
}
832832

833+
def AMDGPU_ScaledMFMAOp :
    AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                              Pure]>,
    Arguments<(ins
                   I32Attr:$m,
                   I32Attr:$n,
                   I32Attr:$k,
                   MFMAInTypes:$sourceA,
                   MFMAInTypes:$sourceB,
                   MFMAOutTypes:$destC,
                   I32Attr:$scaleA,
                   I32Attr:$scaleB,
                   I32Attr:$opselA,
                   I32Attr:$opselB)>,
    Results<(outs MFMAOutTypes: $destD)> {
  let summary = "MLIR wrapper for CDNA scaled mfma instructions";
  let description = [{
    The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
    for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
    multiple outer products in order to allow fast matrix multiplication.

    The wrapper will select an appropriate `mfma` instruction, if one is available,
    based on the provided `m`, `k`, and `n` attributes, along with the
    types of the source and destination arguments.

    Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
    intrinsics that take an integer type of width `4K`. For example,
    one can provide a vector<4xi8> as an argument to an MFMA instruction that
    logically takes 4 i8s but whose intrinsics are specified to take an i32.
    In these cases, the bytes in the vector will be concatenated in little-endian
    order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).

    This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
    - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
    fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
    size.
    - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
    are omitted from this wrapper.
    - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for
    double-precision operations on gfx94x and so are not included here.
  }];
  let assemblyFormat = [{
    $sourceA `*` $sourceB `+` $destC
    attr-dict
    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
  }];
  let hasVerifier = 1;
}
833881
#endif // AMDGPU

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
833833
mfma.getBlocks(), chipset);
834834
}
835835

836+
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
837+
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
838+
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
839+
smfma.getSourceB().getType(),
840+
smfma.getDestC().getType(), smfma.getM(),
841+
smfma.getN(), smfma.getK(), 1u, chipset);
842+
}
843+
836844
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
837845
/// if one exists. This includes checking to ensure the intrinsic is supported
838846
/// on the architecture you are compiling for.
@@ -954,6 +962,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
954962
}
955963
};
956964

965+
/// Lowers `amdgpu.scaled_mfma` to the corresponding
/// `rocdl.mfma.scale.*.f8f6f4` intrinsic. These intrinsics take the source
/// and destination operands followed by the a/b format codes, then the
/// (opsel, scale) pairs for each source.
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    // The intrinsics return bf16 results as vectors of i16; lower to the
    // intrinsic's result type here and bitcast back afterwards.
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    // The mfma.scale.* intrinsics emitted below only exist on gfx950;
    // earlier gfx9 chips (gfx908/90a/94x) have no scaled MFMA support.
    if (chipset != kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching Scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
    Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
    Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
    Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
    loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                           createI32Constant(rewriter, loc, bTypeCode),
                           /*scale A byte=*/opselA, /*scale A=*/scaleA,
                           /*scale B byte=*/opselB, /*scale B=*/scaleB});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
1012+
9571013
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
9581014
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
9591015
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1530,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
14741530
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
14751531
ROCDL::RawPtrBufferAtomicCmpSwap>,
14761532
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1477-
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1478-
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
1479-
GatherToLDSOpLowering>(converter, chipset);
1533+
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1534+
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1535+
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1536+
chipset);
14801537
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
14811538
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,49 @@ LogicalResult GatherToLDSOp::verify() {
506506
return success();
507507
}
508508

509+
LogicalResult ScaledMFMAOp::verify() {
  // The opsel attributes select a byte within the 32-bit scale operands, so
  // only the low 8 bits may be set.
  if (getOpselA() > 0xFF)
    return emitOpError("Opsel A must be a zero extended 8 bit value.");

  if (getOpselB() > 0xFF)
    return emitOpError("Opsel B must be a zero extended 8 bit value.");

  // Scaled MFMA only supports the small-float element types handled by the
  // f8f6f4 intrinsic family: fp8 (E4M3FN, E5M2), fp6 (E2M3FN, E3M2FN), and
  // fp4 (E2M1FN).
  auto isValidElemType = [](Type elemType) {
    return isa<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
               Float6E3M2FNType, Float4E2M1FNType>(elemType);
  };

  if (!isValidElemType(getElementTypeOrSelf(getSourceA().getType())))
    return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
  if (!isValidElemType(getElementTypeOrSelf(getSourceB().getType())))
    return emitOpError("Source B must be of element type fp4, fp6 or fp8.");

  // Only two tile configurations exist: M=N=32, K=64 and M=N=16, K=128.
  unsigned m = getM();
  unsigned n = getN();
  unsigned k = getK();
  bool is32x32x64 = (m == 32 && n == 32 && k == 64);
  bool is16x16x128 = (m == 16 && n == 16 && k == 128);
  if (!is32x32x64 && !is16x16x128)
    return emitOpError("Invalid tile size for scaled mfma.");

  return success();
}
551+
509552
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
510553

511554
#define GET_ATTRDEF_CLASSES

mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5151

5252
func.return
5353
}
54+
55+
// Check lowering of amdgpu.scaled_mfma to rocdl.mfma.scale.*.f8f6f4
// intrinsics for each supported element type (fp8 E4M3FN/E5M2, fp6
// E2M3FN/E3M2FN, fp4 E2M1FN) at both tile sizes (32x32x64 and 16x16x128).
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
                    %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
                    %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {

  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32

  // fp8 E4M3FN sources (type code 0), 32 elements packed into vector<8xi32>.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>

  // CHECK: llvm.bitcast

  // fp8 E5M2 sources (type code 1).
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>

  // CHECK: llvm.bitcast

  // fp6 E2M3FN sources (type code 2), 32 elements packed into vector<6xi32>.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>

  // CHECK: llvm.bitcast
  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32

  // fp6 E3M2FN sources (type code 3).
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>

  // CHECK: llvm.bitcast
  // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32

  // fp4 E2M1FN sources (type code 4), 32 elements packed into vector<4xi32>.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>

  func.return
}

0 commit comments

Comments
 (0)