2222#include " ../LLVMCommon/MemRefDescriptor.h"
2323
2424#include " llvm/ADT/STLExtras.h"
25+ #include " llvm/ADT/TypeSwitch.h"
2526#include < optional>
2627
2728namespace mlir {
@@ -36,6 +37,7 @@ using namespace mlir::amdgpu;
3637constexpr Chipset kGfx908 = Chipset(9 , 0 , 8 );
3738constexpr Chipset kGfx90a = Chipset(9 , 0 , 0xa );
3839constexpr Chipset kGfx942 = Chipset(9 , 4 , 2 );
40+ constexpr Chipset kGfx950 = Chipset(9 , 5 , 0 );
3941
4042// / Convert an unsigned number `val` to i32.
4143static Value convertUnsignedToI32 (ConversionPatternRewriter &rewriter,
@@ -494,18 +496,33 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
494496// / and LLVM AMDGPU intrinsics convention.
495497// /
496498// / Specifically:
497- // / 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
498- // / 2. If the element type is bfloat16, bitcast it to i16.
499+ // / 1. If the element type is bfloat16, bitcast it to i16.
500+ // / 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
501+ // / instead, which is what the f8f6f4 intrinsics use.
502+ // / 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
503+ // / integer.
504+ // /
505+ // / Note that the type of `input` has already been LLVM type converted:
506+ // / therefore 8-bit and smaller floats are represented as their corresponding
507+ // / `iN` integers.
499508static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
500509 Location loc, Value input) {
501510 Type inputType = input.getType ();
502511 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
503512 if (vectorType.getElementType ().isBF16 ())
504513 return rewriter.create <LLVM::BitcastOp>(
505514 loc, vectorType.clone (rewriter.getI16Type ()), input);
506- if (vectorType.getElementType ().isInteger (8 )) {
515+ if (vectorType.getElementType ().isInteger (8 ) &&
516+ vectorType.getNumElements () <= 8 )
507517 return rewriter.create <LLVM::BitcastOp>(
508518 loc, rewriter.getIntegerType (vectorType.getNumElements () * 8 ), input);
519+ if (isa<IntegerType>(vectorType.getElementType ()) &&
520+ vectorType.getElementTypeBitWidth () <= 8 ) {
521+ int64_t numWords = llvm::divideCeil (
522+ vectorType.getNumElements () * vectorType.getElementTypeBitWidth (),
523+ 32 );
524+ return rewriter.create <LLVM::BitcastOp>(
525+ loc, VectorType::get (numWords, rewriter.getI32Type ()), input);
509526 }
510527 }
511528 return input;
@@ -622,12 +639,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
622639 Chipset chipset) {
623640 uint32_t m = mfma.getM (), n = mfma.getN (), k = mfma.getK (),
624641 b = mfma.getBlocks ();
625- Type sourceElem = mfma.getSourceA ().getType ();
626- if (auto sourceType = dyn_cast<VectorType>(sourceElem))
627- sourceElem = sourceType.getElementType ();
628- Type destElem = mfma.getDestC ().getType ();
629- if (auto destType = dyn_cast<VectorType>(destElem))
630- destElem = destType.getElementType ();
642+ Type sourceElem = getElementTypeOrSelf (mfma.getSourceA ().getType ());
643+ Type destElem = getElementTypeOrSelf (mfma.getDestC ().getType ());
631644
632645 if (sourceElem.isF32 () && destElem.isF32 ()) {
633646 if (mfma.getReducePrecision () && chipset >= kGfx942 ) {
@@ -649,6 +662,12 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
649662 }
650663
651664 if (sourceElem.isF16 () && destElem.isF32 ()) {
665+ if (chipset >= kGfx950 ) {
666+ if (m == 32 && n == 32 && k == 16 && b == 1 )
667+ return ROCDL::mfma_f32_32x32x16_f16::getOperationName ();
668+ if (m == 16 && n == 16 && k == 32 && b == 1 )
669+ return ROCDL::mfma_f32_16x16x32_f16::getOperationName ();
670+ }
652671 if (m == 32 && n == 32 && k == 4 && b == 2 )
653672 return ROCDL::mfma_f32_32x32x4f16::getOperationName ();
654673 if (m == 16 && n == 16 && k == 4 && b == 4 )
@@ -661,20 +680,25 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
661680 return ROCDL::mfma_f32_16x16x16f16::getOperationName ();
662681 }
663682
664- if (sourceElem.isBF16 () && destElem.isF32 () && chipset >= kGfx90a ) {
665- if (m == 32 && n == 32 && k == 4 && b == 2 )
666- return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName ();
667- if (m == 16 && n == 16 && k == 4 && b == 4 )
668- return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName ();
669- if (m == 4 && n == 4 && k == 4 && b == 16 )
670- return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName ();
671- if (m == 32 && n == 32 && k == 8 && b == 1 )
672- return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName ();
673- if (m == 16 && n == 16 && k == 16 && b == 1 )
674- return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName ();
675- }
676-
677683 if (sourceElem.isBF16 () && destElem.isF32 ()) {
684+ if (chipset >= kGfx950 ) {
685+ if (m == 32 && n == 32 && k == 16 && b == 1 )
686+ return ROCDL::mfma_f32_32x32x16_bf16::getOperationName ();
687+ if (m == 16 && n == 16 && k == 32 && b == 1 )
688+ return ROCDL::mfma_f32_16x16x32_bf16::getOperationName ();
689+ }
690+ if (chipset >= kGfx90a ) {
691+ if (m == 32 && n == 32 && k == 4 && b == 2 )
692+ return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName ();
693+ if (m == 16 && n == 16 && k == 4 && b == 4 )
694+ return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName ();
695+ if (m == 4 && n == 4 && k == 4 && b == 16 )
696+ return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName ();
697+ if (m == 32 && n == 32 && k == 8 && b == 1 )
698+ return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName ();
699+ if (m == 16 && n == 16 && k == 16 && b == 1 )
700+ return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName ();
701+ }
678702 if (m == 32 && n == 32 && k == 2 && b == 2 )
679703 return ROCDL::mfma_f32_32x32x2bf16::getOperationName ();
680704 if (m == 16 && n == 16 && k == 2 && b == 4 )
@@ -687,7 +711,13 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
687711 return ROCDL::mfma_f32_16x16x8bf16::getOperationName ();
688712 }
689713
690- if (isa<IntegerType>(sourceElem) && destElem.isInteger (32 )) {
714+ if (sourceElem.isInteger (8 ) && destElem.isInteger (32 )) {
715+ if (chipset >= kGfx950 ) {
716+ if (m == 32 && n == 32 && k == 32 && b == 1 )
717+ return ROCDL::mfma_i32_32x32x32_i8::getOperationName ();
718+ if (m == 16 && n == 16 && k == 64 && b == 1 )
719+ return ROCDL::mfma_i32_16x16x64_i8::getOperationName ();
720+ }
691721 if (m == 32 && n == 32 && k == 4 && b == 2 )
692722 return ROCDL::mfma_i32_32x32x4i8::getOperationName ();
693723 if (m == 16 && n == 16 && k == 4 && b == 4 )
@@ -750,6 +780,59 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
750780 return std::nullopt ;
751781}
752782
783+ static std::optional<uint32_t > mfmaTypeSelectCode (Type mlirElemType) {
784+ return llvm::TypeSwitch<Type, std::optional<uint32_t >>(mlirElemType)
785+ .Case ([](Float8E4M3FNType) { return 0u ; })
786+ .Case ([](Float8E5M2Type) { return 1u ; })
787+ .Case ([](Float6E2M3FNType) { return 2u ; })
788+ .Case ([](Float6E3M2FNType) { return 3u ; })
789+ .Case ([](Float4E2M1FNType) { return 4u ; })
790+ .Default ([](Type) { return std::nullopt ; });
791+ }
792+
793+ // / If there is a scaled MFMA instruction for the input element types `aType`
794+ // / and `bType`, output type `destType`, problem size M, N, K, and B (number of
795+ // / blocks) on the given `chipset`, return a tuple consisting of the
796+ // / OperationName of the intrinsic and the type codes that need to be passed to
797+ // / that intrinsic. Note that this is also used to implement some un-scaled
798+ // / MFMAs, since the compiler represents the ordinary instruction as a "scaled"
799+ // / MFMA with a scale of 0.
800+ static std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
801+ mfmaOpToScaledIntrinsic (Type aType, Type bType, Type destType, uint32_t m,
802+ uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
803+ aType = getElementTypeOrSelf (aType);
804+ bType = getElementTypeOrSelf (bType);
805+ destType = getElementTypeOrSelf (destType);
806+
807+ if (chipset < kGfx950 )
808+ return std::nullopt ;
809+ if (!isa<Float32Type>(destType))
810+ return std::nullopt ;
811+
812+ std::optional<uint32_t > aTypeCode = mfmaTypeSelectCode (aType);
813+ std::optional<uint32_t > bTypeCode = mfmaTypeSelectCode (bType);
814+ if (!aTypeCode || !bTypeCode)
815+ return std::nullopt ;
816+
817+ if (m == 32 && n == 32 && k == 64 && b == 1 )
818+ return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName (),
819+ *aTypeCode, *bTypeCode};
820+ if (m == 16 && n == 16 && k == 128 && b == 1 )
821+ return std::tuple{
822+ ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName (), *aTypeCode,
823+ *bTypeCode};
824+
825+ return std::nullopt ;
826+ }
827+
828+ static std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
829+ mfmaOpToScaledIntrinsic (MFMAOp mfma, Chipset chipset) {
830+ return mfmaOpToScaledIntrinsic (
831+ mfma.getSourceA ().getType (), mfma.getSourceB ().getType (),
832+ mfma.getDestC ().getType (), mfma.getM (), mfma.getN (), mfma.getK (),
833+ mfma.getBlocks (), chipset);
834+ }
835+
753836// / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
754837// / if one exists. This includes checking to ensure the intrinsic is supported
755838// / on the architecture you are compiling for.
@@ -829,16 +912,40 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
829912 op.getNegateA () | (op.getNegateB () << 1 ) | (op.getNegateC () << 2 );
830913 }
831914 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic (op, chipset);
832- if (!maybeIntrinsic.has_value ())
915+ std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
916+ maybeScaledIntrinsic = mfmaOpToScaledIntrinsic (op, chipset);
917+ if (!maybeIntrinsic.has_value () && !maybeScaledIntrinsic.has_value ())
833918 return op.emitOpError (" no intrinsic matching MFMA size on given chipset" );
834- OperationState loweredOp (loc, *maybeIntrinsic);
919+
920+ bool isScaled =
921+ !maybeIntrinsic.has_value () && maybeScaledIntrinsic.has_value ();
922+ if (isScaled &&
923+ (adaptor.getAbid () > 0 || getBlgpField > 0 || op.getCbsz () > 0 )) {
924+ return op.emitOpError (
925+ " non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
926+ " be scaled as those fields are used for type information" );
927+ }
928+
929+ StringRef intrinsicName =
930+ isScaled ? std::get<0 >(*maybeScaledIntrinsic) : *maybeIntrinsic;
931+ OperationState loweredOp (loc, intrinsicName);
835932 loweredOp.addTypes (intrinsicOutType);
836933 loweredOp.addOperands (
837934 {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
838935 convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
839- adaptor.getDestC (), createI32Constant (rewriter, loc, op.getCbsz ()),
840- createI32Constant (rewriter, loc, op.getAbid ()),
841- createI32Constant (rewriter, loc, getBlgpField)});
936+ adaptor.getDestC ()});
937+ if (isScaled) {
938+ Value zero = createI32Constant (rewriter, loc, 0 );
939+ auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
940+ loweredOp.addOperands ({createI32Constant (rewriter, loc, aTypeCode),
941+ createI32Constant (rewriter, loc, bTypeCode),
942+ /* scale A byte=*/ zero, /* scale A=*/ zero,
943+ /* scale B byte=*/ zero, /* scale B=*/ zero});
944+ } else {
945+ loweredOp.addOperands ({createI32Constant (rewriter, loc, op.getCbsz ()),
946+ createI32Constant (rewriter, loc, op.getAbid ()),
947+ createI32Constant (rewriter, loc, getBlgpField)});
948+ };
842949 Value lowered = rewriter.create (loweredOp)->getResult (0 );
843950 if (outType != intrinsicOutType)
844951 lowered = rewriter.create <LLVM::BitcastOp>(loc, outType, lowered);
0 commit comments