@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
612612
613613} // namespace
614614
615- // / Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
616- // / and LLVM AMDGPU intrinsics convention .
615+ // / Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
616+ // / expected by scaled matrix multiply intrinsics (MFMA/WMMA) .
617617// /
618618// / Specifically:
619619// / 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
627627// / Note that the type of `input` has already been LLVM type converted:
628628// / therefore 8-bit and smaller floats are represented as their corresponding
629629// / `iN` integers.
630- static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
631- Location loc, Value input,
632- bool allowBf16 = true ) {
630+ static Value packSmallFloatVectorOperand (ConversionPatternRewriter &rewriter,
631+ Location loc, Value input,
632+ bool allowBf16 = true ) {
633633 Type inputType = input.getType ();
634634 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
635635 if (vectorType.getElementType ().isBF16 () && !allowBf16)
@@ -653,23 +653,59 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
653653 return input;
654654}
655655
656- // / Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
657- // / dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
656+ // / Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
657+ // / AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
658658// /
659659// / Specifically:
660660// / 1. If `input` is a i8 value, zero extend it to i32
661- // / 2. If `input` is a vector of length 4 and type i8, cast it to i32
661+ // / 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
662662// /
663663// / Note that the type of `input` has already been LLVM type converted:
664664// / therefore 8-bit and smaller floats are represented as their corresponding
665665// / `iN` integers.
666- static Value castMFMAScaleOperand (ConversionPatternRewriter &rewriter,
667- Location loc, Value input) {
666+ static Value castScaleOperand (ConversionPatternRewriter &rewriter, Location loc ,
667+ Value input) {
668668 Type inputType = input.getType ();
669- Type outputType = rewriter.getI32Type ();
669+
670+ // Handle scalar i8: zero extend to i32.
670671 if (auto intType = dyn_cast<IntegerType>(inputType))
671- return LLVM::ZExtOp::create (rewriter, loc, outputType, input);
672- return LLVM::BitcastOp::create (rewriter, loc, outputType, input);
672+ return LLVM::ZExtOp::create (rewriter, loc, rewriter.getI32Type (), input);
673+
674+ // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
675+ if (auto vectorType = dyn_cast<VectorType>(inputType)) {
676+ int64_t numElements = vectorType.getNumElements ();
677+ assert ((numElements == 4 || numElements == 8 ) &&
678+ " scale operand must be a vector of length 4 or 8" );
679+ IntegerType outputType =
680+ (numElements == 4 ) ? rewriter.getI32Type () : rewriter.getI64Type ();
681+ return LLVM::BitcastOp::create (rewriter, loc, outputType, input);
682+ }
683+
684+ llvm_unreachable (" unexpected input type for scale operand" );
685+ }
686+
687+ // / Maps f8 scale element types to WMMA scale format codes.
688+ static std::optional<uint32_t > getWmmaScaleFormat (Type elemType) {
689+ return TypeSwitch<Type, std::optional<uint32_t >>(elemType)
690+ .Case ([](Float8E8M0FNUType) { return 0 ; })
691+ .Case ([](Float8E4M3FNType) { return 2 ; })
692+ .Default (std::nullopt );
693+ }
694+
695+ // / Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
696+ // / and scale block size (16 or 32).
697+ static std::optional<StringRef>
698+ getScaledWmmaIntrinsicName (int64_t m, int64_t n, int64_t k, bool isScale16) {
699+ if (m == 16 && n == 16 && k == 128 )
700+ return isScale16
701+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName ()
702+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName ();
703+
704+ if (m == 32 && n == 16 && k == 128 )
705+ return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName ()
706+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName ();
707+
708+ return std::nullopt ;
673709}
674710
675711// / Push an input operand. If it is a float type, nothing to do. If it is
@@ -918,7 +954,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
918954 return std::nullopt ;
919955}
920956
921- static std::optional<uint32_t > mfmaTypeSelectCode (Type mlirElemType) {
957+ static std::optional<uint32_t > smallFloatTypeToFormatCode (Type mlirElemType) {
922958 return llvm::TypeSwitch<Type, std::optional<uint32_t >>(mlirElemType)
923959 .Case ([](Float8E4M3FNType) { return 0u ; })
924960 .Case ([](Float8E5M2Type) { return 1u ; })
@@ -947,8 +983,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
947983 if (!isa<Float32Type>(destType))
948984 return std::nullopt ;
949985
950- std::optional<uint32_t > aTypeCode = mfmaTypeSelectCode (aType);
951- std::optional<uint32_t > bTypeCode = mfmaTypeSelectCode (bType);
986+ std::optional<uint32_t > aTypeCode = smallFloatTypeToFormatCode (aType);
987+ std::optional<uint32_t > bTypeCode = smallFloatTypeToFormatCode (bType);
952988 if (!aTypeCode || !bTypeCode)
953989 return std::nullopt ;
954990
@@ -1212,9 +1248,9 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
12121248 }();
12131249 OperationState loweredOp (loc, intrinsicName);
12141250 loweredOp.addTypes (intrinsicOutType);
1215- loweredOp.addOperands ({convertMFMAVectorOperand (
1251+ loweredOp.addOperands ({packSmallFloatVectorOperand (
12161252 rewriter, loc, adaptor.getSourceA (), allowBf16),
1217- convertMFMAVectorOperand (
1253+ packSmallFloatVectorOperand (
12181254 rewriter, loc, adaptor.getSourceB (), allowBf16),
12191255 adaptor.getDestC ()});
12201256 if (isScaled) {
@@ -1261,8 +1297,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
12611297 OperationState loweredOp (loc, intrinsicName);
12621298 loweredOp.addTypes (intrinsicOutType);
12631299 loweredOp.addOperands (
1264- {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
1265- convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
1300+ {packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceA ()),
1301+ packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceB ()),
12661302 adaptor.getDestC ()});
12671303 Value scalesIdxA =
12681304 createI32Constant (rewriter, loc, adaptor.getScalesIdxA ());
@@ -1273,10 +1309,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
12731309 createI32Constant (rewriter, loc, bTypeCode),
12741310 /* scales idx A=*/ scalesIdxA,
12751311 /* scales A*/
1276- castMFMAScaleOperand (rewriter, loc, adaptor.getScalesA ()),
1312+ castScaleOperand (rewriter, loc, adaptor.getScalesA ()),
12771313 /* scales idx B=*/ scalesIdxB,
12781314 /* scales B*/
1279- castMFMAScaleOperand (rewriter, loc, adaptor.getScalesB ())});
1315+ castScaleOperand (rewriter, loc, adaptor.getScalesB ())});
12801316 Value lowered = rewriter.create (loweredOp)->getResult (0 );
12811317 rewriter.replaceOp (op, lowered);
12821318 return success ();
@@ -1363,6 +1399,110 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13631399 }
13641400};
13651401
1402+ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern <ScaledWMMAOp> {
1403+ ScaledWMMAOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
1404+ : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1405+
1406+ Chipset chipset;
1407+
1408+ LogicalResult
1409+ matchAndRewrite (ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1410+ ConversionPatternRewriter &rewriter) const override {
1411+ Location loc = op.getLoc ();
1412+ auto outType =
1413+ typeConverter->convertType <VectorType>(op.getDestD ().getType ());
1414+ if (!outType)
1415+ return rewriter.notifyMatchFailure (op, " type conversion failed" );
1416+
1417+ if (chipset < Chipset (12 , 5 , 0 ))
1418+ return op->emitOpError (" WMMA scale only supported on gfx1250+" );
1419+
1420+ int64_t m = op.getM ();
1421+ int64_t n = op.getN ();
1422+ int64_t k = op.getK ();
1423+
1424+ Type aElemType = getElementTypeOrSelf (op.getSourceA ().getType ());
1425+ Type bElemType = getElementTypeOrSelf (op.getSourceB ().getType ());
1426+
1427+ std::optional<uint32_t > aFmtCode = smallFloatTypeToFormatCode (aElemType);
1428+ std::optional<uint32_t > bFmtCode = smallFloatTypeToFormatCode (bElemType);
1429+
1430+ if (!aFmtCode || !bFmtCode)
1431+ return op.emitOpError (" unsupported element types for scaled_wmma" );
1432+
1433+ // Get scale vector types and determine variant (scale vs scale16).
1434+ auto scaleAVecType = cast<VectorType>(op.getScaleA ().getType ());
1435+ auto scaleBVecType = cast<VectorType>(op.getScaleB ().getType ());
1436+
1437+ if (scaleAVecType.getNumElements () != scaleBVecType.getNumElements ())
1438+ return op.emitOpError (" scaleA and scaleB must have equal vector length" );
1439+
1440+ // Extract scale format from element types.
1441+ Type scaleAElemType = scaleAVecType.getElementType ();
1442+ Type scaleBElemType = scaleBVecType.getElementType ();
1443+
1444+ std::optional<uint32_t > scaleAFmt = getWmmaScaleFormat (scaleAElemType);
1445+ std::optional<uint32_t > scaleBFmt = getWmmaScaleFormat (scaleBElemType);
1446+
1447+ if (!scaleAFmt || !scaleBFmt)
1448+ return op.emitOpError (" unsupported scale element types" );
1449+
1450+ // Determine which intrinsic to use based on dimensions.
1451+ bool isScale16 = (scaleAVecType.getNumElements () == 8 );
1452+ std::optional<StringRef> intrinsicName =
1453+ getScaledWmmaIntrinsicName (m, n, k, isScale16);
1454+ if (!intrinsicName)
1455+ return op.emitOpError (" unsupported scaled_wmma dimensions: " )
1456+ << m << " x" << n << " x" << k;
1457+
1458+ SmallVector<NamedAttribute, 8 > attrs;
1459+
1460+ // The f4 variant does not have fmtA and fmtB attributes.
1461+ bool is32x16 = (m == 32 && n == 16 && k == 128 );
1462+ if (!is32x16) {
1463+ attrs.emplace_back (" fmtA" , rewriter.getI32IntegerAttr (*aFmtCode));
1464+ attrs.emplace_back (" fmtB" , rewriter.getI32IntegerAttr (*bFmtCode));
1465+ }
1466+
1467+ // modC uses default value of 0.
1468+ attrs.emplace_back (" modC" , rewriter.getI16IntegerAttr (0 ));
1469+
1470+ // Scale attributes.
1471+ attrs.emplace_back (" scaleAType" ,
1472+ rewriter.getI32IntegerAttr (op.getAFirstScaleLane ()));
1473+ attrs.emplace_back (" fmtScaleA" , rewriter.getI32IntegerAttr (*scaleAFmt));
1474+ attrs.emplace_back (" scaleBType" ,
1475+ rewriter.getI32IntegerAttr (op.getBFirstScaleLane ()));
1476+ attrs.emplace_back (" fmtScaleB" , rewriter.getI32IntegerAttr (*scaleBFmt));
1477+
1478+ // Reuse flags use default value of false.
1479+ attrs.emplace_back (" reuseA" , rewriter.getBoolAttr (false ));
1480+ attrs.emplace_back (" reuseB" , rewriter.getBoolAttr (false ));
1481+
1482+ // Convert typed float vectors to packed format.
1483+ Value sourceA =
1484+ packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceA ());
1485+ Value sourceB =
1486+ packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceB ());
1487+
1488+ // Pack scale vectors into i32/i64.
1489+ Value packedScaleA = castScaleOperand (rewriter, loc, adaptor.getScaleA ());
1490+ Value packedScaleB = castScaleOperand (rewriter, loc, adaptor.getScaleB ());
1491+
1492+ // Create the intrinsic call.
1493+ OperationState loweredOp (loc, *intrinsicName);
1494+ loweredOp.addTypes (outType);
1495+ loweredOp.addOperands (
1496+ {sourceA, sourceB, adaptor.getDestC (), packedScaleA, packedScaleB});
1497+ loweredOp.addAttributes (attrs);
1498+
1499+ Operation *lowered = rewriter.create (loweredOp);
1500+ rewriter.replaceOp (op, lowered->getResults ());
1501+
1502+ return success ();
1503+ }
1504+ };
1505+
13661506struct TransposeLoadOpLowering
13671507 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
13681508 TransposeLoadOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
@@ -2408,10 +2548,11 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
24082548 ROCDL::RawPtrBufferAtomicCmpSwap>,
24092549 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
24102550 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2411- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
2412- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2413- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2414- GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
2551+ WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
2552+ ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
2553+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2554+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2555+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,
24152556 AMDGPUMakeDmaBaseLowering>(converter, chipset);
24162557 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
24172558}
0 commit comments