@@ -619,8 +619,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
619619
620620} // namespace
621621
622- // / Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
623- // / and LLVM AMDGPU intrinsics convention .
622+ // / Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
623+ // / expected by scaled matrix multiply intrinsics (MFMA/WMMA) .
624624// /
625625// / Specifically:
626626// / 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -634,9 +634,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
634634// / Note that the type of `input` has already been LLVM type converted:
635635// / therefore 8-bit and smaller floats are represented as their corresponding
636636// / `iN` integers.
637- static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
638- Location loc, Value input,
639- bool allowBf16 = true ) {
637+ static Value packSmallFloatVectorOperand (ConversionPatternRewriter &rewriter,
638+ Location loc, Value input,
639+ bool allowBf16 = true ) {
640640 Type inputType = input.getType ();
641641 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
642642 if (vectorType.getElementType ().isBF16 () && !allowBf16)
@@ -660,23 +660,60 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
660660 return input;
661661}
662662
663- // / Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
664- // / dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
663+ // / Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
664+ // / AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
665665// /
666666// / Specifically:
667667// / 1. If `input` is a i8 value, zero extend it to i32
668- // / 2. If `input` is a vector of length 4 and type i8, cast it to i32
668+ // / 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
669669// /
670670// / Note that the type of `input` has already been LLVM type converted:
671671// / therefore 8-bit and smaller floats are represented as their corresponding
672672// / `iN` integers.
673- static Value castMFMAScaleOperand (ConversionPatternRewriter &rewriter,
674- Location loc, Value input) {
675- Type inputType = input.getType ();
676- Type outputType = rewriter.getI32Type ();
677- if (auto intType = dyn_cast<IntegerType>(inputType))
678- return LLVM::ZExtOp::create (rewriter, loc, outputType, input);
679- return LLVM::BitcastOp::create (rewriter, loc, outputType, input);
673+ static Value castScaleOperand (ConversionPatternRewriter &rewriter, Location loc,
674+ Value input) {
675+ return TypeSwitch<Type, Value>(input.getType ())
676+ .Case <IntegerType>([&](IntegerType) {
677+ // Handle scalar i8: zero extend to i32.
678+ return LLVM::ZExtOp::create (rewriter, loc, rewriter.getI32Type (),
679+ input);
680+ })
681+ .Case <VectorType>([&](VectorType vectorType) {
682+ // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
683+ int64_t numElements = vectorType.getNumElements ();
684+ assert ((numElements == 4 || numElements == 8 ) &&
685+ " scale operand must be a vector of length 4 or 8" );
686+ IntegerType outputType =
687+ (numElements == 4 ) ? rewriter.getI32Type () : rewriter.getI64Type ();
688+ return LLVM::BitcastOp::create (rewriter, loc, outputType, input);
689+ })
690+ .Default ([](Type) -> Value {
691+ llvm_unreachable (" unexpected input type for scale operand" );
692+ });
693+ }
694+
695+ // / Maps f8 scale element types to WMMA scale format codes.
696+ static std::optional<uint32_t > getWmmaScaleFormat (Type elemType) {
697+ return TypeSwitch<Type, std::optional<uint32_t >>(elemType)
698+ .Case ([](Float8E8M0FNUType) { return 0 ; })
699+ .Case ([](Float8E4M3FNType) { return 2 ; })
700+ .Default (std::nullopt );
701+ }
702+
703+ // / Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
704+ // / and scale block size (16 or 32).
705+ static std::optional<StringRef>
706+ getScaledWmmaIntrinsicName (int64_t m, int64_t n, int64_t k, bool isScale16) {
707+ if (m == 16 && n == 16 && k == 128 )
708+ return isScale16
709+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName ()
710+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName ();
711+
712+ if (m == 32 && n == 16 && k == 128 )
713+ return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName ()
714+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName ();
715+
716+ return std::nullopt ;
680717}
681718
682719// / Push an input operand. If it is a float type, nothing to do. If it is
@@ -925,7 +962,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
925962 return std::nullopt ;
926963}
927964
928- static std::optional<uint32_t > mfmaTypeSelectCode (Type mlirElemType) {
965+ static std::optional<uint32_t > smallFloatTypeToFormatCode (Type mlirElemType) {
929966 return llvm::TypeSwitch<Type, std::optional<uint32_t >>(mlirElemType)
930967 .Case ([](Float8E4M3FNType) { return 0u ; })
931968 .Case ([](Float8E5M2Type) { return 1u ; })
@@ -954,8 +991,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
954991 if (!isa<Float32Type>(destType))
955992 return std::nullopt ;
956993
957- std::optional<uint32_t > aTypeCode = mfmaTypeSelectCode (aType);
958- std::optional<uint32_t > bTypeCode = mfmaTypeSelectCode (bType);
994+ std::optional<uint32_t > aTypeCode = smallFloatTypeToFormatCode (aType);
995+ std::optional<uint32_t > bTypeCode = smallFloatTypeToFormatCode (bType);
959996 if (!aTypeCode || !bTypeCode)
960997 return std::nullopt ;
961998
@@ -1219,9 +1256,9 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
12191256 }();
12201257 OperationState loweredOp (loc, intrinsicName);
12211258 loweredOp.addTypes (intrinsicOutType);
1222- loweredOp.addOperands ({convertMFMAVectorOperand (
1259+ loweredOp.addOperands ({packSmallFloatVectorOperand (
12231260 rewriter, loc, adaptor.getSourceA (), allowBf16),
1224- convertMFMAVectorOperand (
1261+ packSmallFloatVectorOperand (
12251262 rewriter, loc, adaptor.getSourceB (), allowBf16),
12261263 adaptor.getDestC ()});
12271264 if (isScaled) {
@@ -1268,8 +1305,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
12681305 OperationState loweredOp (loc, intrinsicName);
12691306 loweredOp.addTypes (intrinsicOutType);
12701307 loweredOp.addOperands (
1271- {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
1272- convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
1308+ {packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceA ()),
1309+ packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceB ()),
12731310 adaptor.getDestC ()});
12741311 Value scalesIdxA =
12751312 createI32Constant (rewriter, loc, adaptor.getScalesIdxA ());
@@ -1280,10 +1317,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
12801317 createI32Constant (rewriter, loc, bTypeCode),
12811318 /* scales idx A=*/ scalesIdxA,
12821319 /* scales A*/
1283- castMFMAScaleOperand (rewriter, loc, adaptor.getScalesA ()),
1320+ castScaleOperand (rewriter, loc, adaptor.getScalesA ()),
12841321 /* scales idx B=*/ scalesIdxB,
12851322 /* scales B*/
1286- castMFMAScaleOperand (rewriter, loc, adaptor.getScalesB ())});
1323+ castScaleOperand (rewriter, loc, adaptor.getScalesB ())});
12871324 Value lowered = rewriter.create (loweredOp)->getResult (0 );
12881325 rewriter.replaceOp (op, lowered);
12891326 return success ();
@@ -1370,6 +1407,111 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13701407 }
13711408};
13721409
1410+ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern <ScaledWMMAOp> {
1411+ ScaledWMMAOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
1412+ : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1413+
1414+ Chipset chipset;
1415+
1416+ LogicalResult
1417+ matchAndRewrite (ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1418+ ConversionPatternRewriter &rewriter) const override {
1419+ Location loc = op.getLoc ();
1420+ auto outType =
1421+ typeConverter->convertType <VectorType>(op.getDestD ().getType ());
1422+ if (!outType)
1423+ return rewriter.notifyMatchFailure (op, " type conversion failed" );
1424+
1425+ if (chipset < kGfx1250 )
1426+ return op->emitOpError (" WMMA scale only supported on gfx1250+" );
1427+
1428+ int64_t m = op.getM ();
1429+ int64_t n = op.getN ();
1430+ int64_t k = op.getK ();
1431+
1432+ Type aElemType = getElementTypeOrSelf (op.getSourceA ().getType ());
1433+ Type bElemType = getElementTypeOrSelf (op.getSourceB ().getType ());
1434+
1435+ std::optional<uint32_t > aFmtCode = smallFloatTypeToFormatCode (aElemType);
1436+ std::optional<uint32_t > bFmtCode = smallFloatTypeToFormatCode (bElemType);
1437+
1438+ if (!aFmtCode || !bFmtCode)
1439+ return op.emitOpError (" unsupported element types for scaled_wmma" );
1440+
1441+ // Get scale vector types and determine variant (scale vs scale16).
1442+ auto scaleAVecType = cast<VectorType>(op.getScaleA ().getType ());
1443+ auto scaleBVecType = cast<VectorType>(op.getScaleB ().getType ());
1444+
1445+ if (scaleAVecType.getNumElements () != scaleBVecType.getNumElements ())
1446+ return op.emitOpError (" scaleA and scaleB must have equal vector length" );
1447+
1448+ // Extract scale format from element types.
1449+ Type scaleAElemType = scaleAVecType.getElementType ();
1450+ Type scaleBElemType = scaleBVecType.getElementType ();
1451+
1452+ std::optional<uint32_t > scaleAFmt = getWmmaScaleFormat (scaleAElemType);
1453+ std::optional<uint32_t > scaleBFmt = getWmmaScaleFormat (scaleBElemType);
1454+
1455+ if (!scaleAFmt || !scaleBFmt)
1456+ return op.emitOpError (" unsupported scale element types" );
1457+
1458+ // Determine which intrinsic to use based on dimensions.
1459+ bool isScale16 = (scaleAVecType.getNumElements () == 8 );
1460+ std::optional<StringRef> intrinsicName =
1461+ getScaledWmmaIntrinsicName (m, n, k, isScale16);
1462+ if (!intrinsicName)
1463+ return op.emitOpError (" unsupported scaled_wmma dimensions: " )
1464+ << m << " x" << n << " x" << k;
1465+
1466+ SmallVector<NamedAttribute, 8 > attrs;
1467+
1468+ // The f4 variant does not have fmtA and fmtB attributes.
1469+ bool is32x16 = (m == 32 && n == 16 && k == 128 );
1470+ if (!is32x16) {
1471+ attrs.emplace_back (" fmtA" , rewriter.getI32IntegerAttr (*aFmtCode));
1472+ attrs.emplace_back (" fmtB" , rewriter.getI32IntegerAttr (*bFmtCode));
1473+ }
1474+
1475+ // modC uses default value of 0.
1476+ attrs.emplace_back (" modC" , rewriter.getI16IntegerAttr (0 ));
1477+
1478+ // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
1479+ // half of the wave that is being selected (0 or 1).
1480+ attrs.emplace_back (
1481+ " scaleAType" , rewriter.getI32IntegerAttr (op.getAFirstScaleLane () / 16 ));
1482+ attrs.emplace_back (" fmtScaleA" , rewriter.getI32IntegerAttr (*scaleAFmt));
1483+ attrs.emplace_back (
1484+ " scaleBType" , rewriter.getI32IntegerAttr (op.getBFirstScaleLane () / 16 ));
1485+ attrs.emplace_back (" fmtScaleB" , rewriter.getI32IntegerAttr (*scaleBFmt));
1486+
1487+ // Reuse flags use default value of false.
1488+ attrs.emplace_back (" reuseA" , rewriter.getBoolAttr (false ));
1489+ attrs.emplace_back (" reuseB" , rewriter.getBoolAttr (false ));
1490+
1491+ // Convert typed float vectors to packed format.
1492+ Value sourceA =
1493+ packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceA ());
1494+ Value sourceB =
1495+ packSmallFloatVectorOperand (rewriter, loc, adaptor.getSourceB ());
1496+
1497+ // Pack scale vectors into i32/i64.
1498+ Value packedScaleA = castScaleOperand (rewriter, loc, adaptor.getScaleA ());
1499+ Value packedScaleB = castScaleOperand (rewriter, loc, adaptor.getScaleB ());
1500+
1501+ // Create the intrinsic call.
1502+ OperationState loweredOp (loc, *intrinsicName);
1503+ loweredOp.addTypes (outType);
1504+ loweredOp.addOperands (
1505+ {sourceA, sourceB, adaptor.getDestC (), packedScaleA, packedScaleB});
1506+ loweredOp.addAttributes (attrs);
1507+
1508+ Operation *lowered = rewriter.create (loweredOp);
1509+ rewriter.replaceOp (op, lowered->getResults ());
1510+
1511+ return success ();
1512+ }
1513+ };
1514+
13731515struct TransposeLoadOpLowering
13741516 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
13751517 TransposeLoadOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
@@ -2780,11 +2922,11 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
27802922 ROCDL::RawPtrBufferAtomicCmpSwap>,
27812923 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
27822924 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2783- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering ,
2784- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering ,
2785- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering ,
2786- GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering ,
2787- AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter ,
2788- chipset);
2925+ WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering ,
2926+ ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering ,
2927+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering ,
2928+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering ,
2929+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,AMDGPUMakeDmaBaseLowering ,
2930+ AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
27892931 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
27902932}
0 commit comments