
Commit 0291d87
Add scaled WMMA to AMDGPU
1 parent 2833d9f
5 files changed (+467, -26 lines)

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 75 additions & 0 deletions
@@ -959,6 +959,15 @@ def MFMAOutTypes : AnyTypeOf<[F64,
 def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
+
+// scaled_wmma
+def ScaledWMMAInTypes
+    : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+                 VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+                 VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F32]>]>;
+
 // wmma
 def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
                              VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
@@ -1226,6 +1235,72 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
+def AMDGPU_ScaledWMMAOp
+    : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>,
+      Arguments<(ins ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                     ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                     ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+                     ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
+                     ScaledWMMAOutTypes:$destC,
+                     VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
+                     ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$a_first_scale_lane,
+                     VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
+                     ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$b_first_scale_lane)>,
+      Results<(outs ScaledWMMAOutTypes:$destD)> {
+  // TODO: E5M3FNU scales are supported, but there is not yet MLIR support for
+  // this datatype. Once we have support for that, update the scaleA and scaleB
+  // types here.
+  let summary = "MLIR wrapper for scaled wmma instructions";
+  let description = [{
+    The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+    `wmma` instructions. These instructions perform matrix multiplication with
+    per-block scaling of the inputs, supporting fp4, fp6, and fp8 data formats.
+
+    The scaled instructions support a block size of 16 or 32 and two tile sizes:
+    - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>)
+    - 32x16x128 with the f4 format only (output: vector<16xf32>)
+
+    Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
+    (either f8E8M0FNU or f8E4M3FN) that are packed into i32/i64 values during
+    lowering. The index attributes (`a_first_scale_lane`, `b_first_scale_lane`)
+    select which register lanes provide the scale values:
+    - Block size 32: For a tile size of 16x16x128, each matrix gets 64 scales
+      stored in half a VGPR, with `a_first_scale_lane`/`b_first_scale_lane`
+      selecting lanes 0-15 (index=0) or 16-31 (index=1). For a tile size of
+      32x16x128, matrix A gets 128 scales in a full VGPR (`a_first_scale_lane`
+      is unused), while matrix B gets 64 scales in half a VGPR.
+
+    - Block size 16: For a tile size of 16x16x128, each matrix gets 128 scales
+      stored in half of two VGPRs, with `a_first_scale_lane`/`b_first_scale_lane`
+      selecting lanes 0-15 (index=0) or 16-31 (index=1) for each of the VGPRs.
+      For 32x16x128, matrix A gets 256 scales in two VGPRs (`a_first_scale_lane`
+      is unused), while matrix B gets 128 scales stored in half of two VGPRs.
+
+    Example:
+    ```mlir
+    // 16x16x128: fp8 inputs
+    %0 = amdgpu.scaled_wmma 16x16x128 (%scaleVecA * %matA) * (%scaleVecB * %matB) + %matC
+      {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+      : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
+        vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+    // 32x16x128: fp4 inputs with different scale indices
+    %1 = amdgpu.scaled_wmma 32x16x128 (%scaleVecD * %matD) * (%scaleVecE * %matE) + %matF
+      {a_first_scale_lane = 0 : i32, b_first_scale_lane = 1 : i32}
+      : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>,
+        vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) ` `
+    `(` $scaleA `*` $sourceA `)` `*`
+    `(` $scaleB `*` $sourceB `)` `+` $destC
+    attr-dict
+    `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_MakeDmaBaseOp :
     AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
     Arguments<(ins Arg<AnyMemRef>:$global,
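
The scale-count bookkeeping in the `amdgpu.scaled_wmma` description above follows a single rule: each operand carries one scale per `blockSize` elements along K for every row it owns (M rows for A, N rows for B). A minimal standalone sketch, not part of this commit, that reproduces the 64/128/256 figures quoted in the description:

```cpp
// Standalone sketch (not part of the commit): reproduces the scale counts in
// the amdgpu.scaled_wmma description from the per-block rule "one scale per
// blockSize K-elements of each row of an operand".
#include <cassert>
#include <cstdio>

// Number of scales an operand with `rows` rows needs for a K dimension of `k`
// and a scaling block size of `blockSize`.
static int scaleCount(int rows, int k, int blockSize) {
  return rows * k / blockSize;
}

int main() {
  constexpr int k = 128, n = 16;
  // Block size 32, 16x16x128: 64 scales per matrix (half a VGPR of bytes).
  assert(scaleCount(/*rows=*/16, k, /*blockSize=*/32) == 64);
  // Block size 32, 32x16x128: A needs 128 scales, B still needs 64.
  assert(scaleCount(32, k, 32) == 128);
  assert(scaleCount(n, k, 32) == 64);
  // Block size 16, 16x16x128: 128 scales per matrix.
  assert(scaleCount(16, k, 16) == 128);
  // Block size 16, 32x16x128: A needs 256 scales, B needs 128.
  assert(scaleCount(32, k, 16) == 256);
  assert(scaleCount(n, k, 16) == 128);
  std::puts("scale counts match the op description");
  return 0;
}
```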

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 167 additions & 26 deletions
@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
-/// and LLVM AMDGPU intrinsics convention.
+/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
+/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
 ///
 /// Specifically:
 /// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input,
-                                      bool allowBf16 = true) {
+static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
+                                         Location loc, Value input,
+                                         bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16() && !allowBf16)
@@ -653,23 +653,59 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
-/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
-/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
+/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
 /// 1. If `input` is a i8 value, zero extend it to i32
-/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32 or i64
 ///
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
-                                  Location loc, Value input) {
+static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
+                              Value input) {
   Type inputType = input.getType();
-  Type outputType = rewriter.getI32Type();
+
+  // Handle scalar i8: zero extend to i32.
   if (auto intType = dyn_cast<IntegerType>(inputType))
-    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
-  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+    return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
+
+  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+    int64_t numElements = vectorType.getNumElements();
+    assert((numElements == 4 || numElements == 8) &&
+           "scale operand must be a vector of length 4 or 8");
+    IntegerType outputType =
+        (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
+    return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+  }
+
+  llvm_unreachable("unexpected input type for scale operand");
+}
+
+/// Maps f8 scale element types to WMMA scale format codes.
+static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
+  return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
+      .Case([](Float8E8M0FNUType) { return 0; })
+      .Case([](Float8E4M3FNType) { return 2; })
+      .Default(std::nullopt);
+}
+
+/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
+/// and scale block size (16 or 32).
+static std::optional<StringRef>
+getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16) {
  if (m == 16 && n == 16 && k == 128)
    return isScale16
               ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
               : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();

  if (m == 32 && n == 16 && k == 128)
    return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
                     : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();

  return std::nullopt;
 }
 
 /// Push an input operand. If it is a float type, nothing to do. If it is
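
The `vector<4xi8>` -> `i32` and `vector<8xi8>` -> `i64` bitcasts in `castScaleOperand` above amount to packing four or eight one-byte scales into a single machine word. A standalone illustration of that byte-level packing, assuming the usual little-endian layout where vector element 0 lands in the least significant byte (this is a sketch of the semantics, not the MLIR lowering itself):

```cpp
// Standalone sketch (not part of the commit): the byte-level effect of the
// vector<4xi8> -> i32 and vector<8xi8> -> i64 bitcasts done by
// castScaleOperand, assuming a little-endian layout where vector element 0
// becomes the least significant byte of the packed word.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t packScales4(const uint8_t (&scales)[4]) {
  uint32_t packed = 0;
  std::memcpy(&packed, scales, sizeof(packed)); // same bytes, reinterpreted
  return packed;
}

static uint64_t packScales8(const uint8_t (&scales)[8]) {
  uint64_t packed = 0;
  std::memcpy(&packed, scales, sizeof(packed));
  return packed;
}

int main() {
  // Four scale bytes (their i8 bit patterns after LLVM type conversion), as
  // used by the block-size-32 ("scale") variants.
  const uint8_t scalesA[4] = {0x7f, 0x80, 0x81, 0x7e};
  std::printf("packed i32 scales: 0x%08x\n", (unsigned)packScales4(scalesA));

  // Eight scale bytes, as used by the block-size-16 ("scale16") variants.
  const uint8_t scalesB[8] = {0x7f, 0x7f, 0x7f, 0x7f, 0x80, 0x80, 0x80, 0x80};
  std::printf("packed i64 scales: 0x%016llx\n",
              (unsigned long long)packScales8(scalesB));
  return 0;
}
```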
@@ -918,7 +954,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
-static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
   return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
       .Case([](Float8E4M3FNType) { return 0u; })
       .Case([](Float8E5M2Type) { return 1u; })
@@ -947,8 +983,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
   if (!isa<Float32Type>(destType))
     return std::nullopt;
 
-  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
-  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+  std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
+  std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
   if (!aTypeCode || !bTypeCode)
     return std::nullopt;
 
@@ -1212,9 +1248,9 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands({convertMFMAVectorOperand(
+    loweredOp.addOperands({packSmallFloatVectorOperand(
                                rewriter, loc, adaptor.getSourceA(), allowBf16),
-                           convertMFMAVectorOperand(
+                           packSmallFloatVectorOperand(
                                rewriter, loc, adaptor.getSourceB(), allowBf16),
                            adaptor.getDestC()});
     if (isScaled) {
@@ -1261,8 +1297,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
     Value scalesIdxA =
         createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
@@ -1273,10 +1309,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
                            createI32Constant(rewriter, loc, bTypeCode),
                            /*scales idx A=*/scalesIdxA,
                            /*scales A*/
-                           castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
+                           castScaleOperand(rewriter, loc, adaptor.getScalesA()),
                            /*scales idx B=*/scalesIdxB,
                            /*scales B*/
-                           castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
+                           castScaleOperand(rewriter, loc, adaptor.getScalesB())});
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     rewriter.replaceOp(op, lowered);
     return success();
@@ -1363,6 +1399,110 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    if (chipset < Chipset(12, 5, 0))
+      return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+    int64_t m = op.getM();
+    int64_t n = op.getN();
+    int64_t k = op.getK();
+
+    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+    if (!aFmtCode || !bFmtCode)
+      return op.emitOpError("unsupported element types for scaled_wmma");
+
+    // Get scale vector types and determine variant (scale vs scale16).
+    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
+    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
+
+    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
+      return op.emitOpError("scaleA and scaleB must have equal vector length");
+
+    // Extract scale format from element types.
+    Type scaleAElemType = scaleAVecType.getElementType();
+    Type scaleBElemType = scaleBVecType.getElementType();
+
+    std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
+    std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
+
+    if (!scaleAFmt || !scaleBFmt)
+      return op.emitOpError("unsupported scale element types");
+
+    // Determine which intrinsic to use based on dimensions.
+    bool isScale16 = (scaleAVecType.getNumElements() == 8);
+    std::optional<StringRef> intrinsicName =
+        getScaledWmmaIntrinsicName(m, n, k, isScale16);
+    if (!intrinsicName)
+      return op.emitOpError("unsupported scaled_wmma dimensions: ")
+             << m << "x" << n << "x" << k;
+
+    SmallVector<NamedAttribute, 8> attrs;
+
+    // The f4 variant does not have fmtA and fmtB attributes.
+    bool is32x16 = (m == 32 && n == 16 && k == 128);
+    if (!is32x16) {
+      attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
+      attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
+    }
+
+    // modC uses default value of 0.
+    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
+
+    // Scale attributes.
+    attrs.emplace_back("scaleAType",
+                       rewriter.getI32IntegerAttr(op.getAFirstScaleLane()));
+    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
+    attrs.emplace_back("scaleBType",
+                       rewriter.getI32IntegerAttr(op.getBFirstScaleLane()));
+    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
+
+    // Reuse flags use default value of false.
+    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
+    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
+
+    // Convert typed float vectors to packed format.
+    Value sourceA =
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
+    Value sourceB =
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
+
+    // Pack scale vectors into i32/i64.
+    Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
+    Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
+
+    // Create the intrinsic call.
+    OperationState loweredOp(loc, *intrinsicName);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands(
+        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
+    loweredOp.addAttributes(attrs);
+
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+
+    return success();
+  }
+};
+
 struct TransposeLoadOpLowering
     : public ConvertOpToLLVMPattern<TransposeLoadOp> {
   TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
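
Putting the pieces of `ScaledWMMAOpLowering` together, the intrinsic is chosen purely from the tile shape and the scale-vector length: four scales per lane select the `scale` variants (i32-packed scales), eight scales per lane select the `scale16` variants (i64-packed scales). A standalone sketch of that decision table, using the ROCDL op class names from this diff as plain strings (illustrative only, not the in-tree helper):

```cpp
// Standalone sketch (not part of the commit): mirrors how ScaledWMMAOpLowering
// maps (m, n, k) plus the scale-vector length onto one of the four ROCDL
// intrinsics used in this patch.
#include <cstdio>
#include <string>

static std::string pickScaledWmmaIntrinsic(int m, int n, int k,
                                           int numScaleElements) {
  const bool isScale16 = (numScaleElements == 8); // 8 scales -> i64 packing
  if (m == 16 && n == 16 && k == 128)
    return isScale16 ? "wmma_scale16_f32_16x16x128_f8f6f4"
                     : "wmma_scale_f32_16x16x128_f8f6f4";
  if (m == 32 && n == 16 && k == 128)
    return isScale16 ? "wmma_scale16_f32_32x16x128_f4"
                     : "wmma_scale_f32_32x16x128_f4";
  return "<unsupported scaled_wmma dimensions>";
}

int main() {
  std::printf("%s\n", pickScaledWmmaIntrinsic(16, 16, 128, 4).c_str());
  std::printf("%s\n", pickScaledWmmaIntrinsic(16, 16, 128, 8).c_str());
  std::printf("%s\n", pickScaledWmmaIntrinsic(32, 16, 128, 4).c_str());
  std::printf("%s\n", pickScaledWmmaIntrinsic(32, 16, 128, 8).c_str());
  return 0;
}
```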
@@ -2408,10 +2548,11 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
-           WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
-           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+           WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
+           ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering,
            AMDGPUMakeDmaBaseLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
