1212#include " mlir/Conversion/LLVMCommon/Pattern.h"
1313#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
1414#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1516#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1617#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
1718#include " mlir/IR/BuiltinTypes.h"
@@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
4243}
4344
4445namespace {
46+ // Define commonly used chipset versions for convenience.
47+ constexpr Chipset kGfx908 = Chipset(9 , 0 , 8 );
48+ constexpr Chipset kGfx90a = Chipset(9 , 0 , 0xa );
49+ constexpr Chipset kGfx940 = Chipset(9 , 4 , 0 );
50+
4551/// Define lowering patterns for raw buffer ops
4652template <typename GpuOp, typename Intrinsic>
4753struct RawBufferOpLowering : public ConvertOpToLLVMPattern <GpuOp> {
@@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
278284 LogicalResult
279285 matchAndRewrite (LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
280286 ConversionPatternRewriter &rewriter) const override {
281- bool requiresInlineAsm =
282- chipset.majorVersion < 9 ||
283- (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a ) ||
284- (chipset.majorVersion == 11 );
287+ bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11 ;
285288
286289 if (requiresInlineAsm) {
287290 auto asmDialectAttr = LLVM::AsmDialectAttr::get (rewriter.getContext (),
@@ -465,7 +468,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
465468 destElem = destType.getElementType ();
466469
467470 if (sourceElem.isF32 () && destElem.isF32 ()) {
468- if (mfma.getReducePrecision () && chipset. minorVersion >= 0x40 ) {
471+ if (mfma.getReducePrecision () && chipset >= kGfx940 ) {
469472 if (m == 32 && n == 32 && k == 4 && b == 1 )
470473 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName ();
471474 if (m == 16 && n == 16 && k == 8 && b == 1 )
@@ -496,7 +499,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
496499 return ROCDL::mfma_f32_16x16x16f16::getOperationName ();
497500 }
498501
499- if (sourceElem.isBF16 () && destElem.isF32 () && chipset. minorVersion >= 0x0a ) {
502+ if (sourceElem.isBF16 () && destElem.isF32 () && chipset >= kGfx90a ) {
500503 if (m == 32 && n == 32 && k == 4 && b == 2 )
501504 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName ();
502505 if (m == 16 && n == 16 && k == 4 && b == 4 )
@@ -533,21 +536,20 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
533536 return ROCDL::mfma_i32_32x32x8i8::getOperationName ();
534537 if (m == 16 && n == 16 && k == 16 && b == 1 )
535538 return ROCDL::mfma_i32_16x16x16i8::getOperationName ();
536- if (m == 32 && n == 32 && k == 16 && b == 1 && chipset. minorVersion >= 0x40 )
539+ if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940 )
537540 return ROCDL::mfma_i32_32x32x16_i8::getOperationName ();
538- if (m == 16 && n == 16 && k == 32 && b == 1 && chipset. minorVersion >= 0x40 )
541+ if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940 )
539542 return ROCDL::mfma_i32_16x16x32_i8::getOperationName ();
540543 }
541544
542- if (sourceElem.isF64 () && destElem.isF64 () && chipset. minorVersion >= 0x0a ) {
545+ if (sourceElem.isF64 () && destElem.isF64 () && chipset >= kGfx90a ) {
543546 if (m == 16 && n == 16 && k == 4 && b == 1 )
544547 return ROCDL::mfma_f64_16x16x4f64::getOperationName ();
545548 if (m == 4 && n == 4 && k == 4 && b == 4 )
546549 return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
547550 }
548551
549- if (sourceElem.isFloat8E5M2FNUZ () && destElem.isF32 () &&
550- chipset.minorVersion >= 0x40 ) {
552+ if (sourceElem.isFloat8E5M2FNUZ () && destElem.isF32 () && chipset >= kGfx940 ) {
551553 // Known to be correct because there are no scalar f8 instructions and
552554 // because a length mismatch will have been caught by the verifier.
553555 Type sourceBElem =
@@ -566,8 +568,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
566568 }
567569 }
568570
569- if (sourceElem.isFloat8E4M3FNUZ () && destElem.isF32 () &&
570- chipset.minorVersion >= 0x40 ) {
571+ if (sourceElem.isFloat8E4M3FNUZ () && destElem.isF32 () && chipset >= kGfx940 ) {
571572 Type sourceBElem =
572573 cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
573574 if (m == 16 && n == 16 && k == 32 && b == 1 ) {
@@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
631632 if (outVecType.getElementType ().isBF16 ())
632633 intrinsicOutType = outVecType.clone (rewriter.getI16Type ());
633634
634- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x08 )
635+ if (chipset.majorVersion != 9 || chipset < kGfx908 )
635636 return op->emitOpError (" MFMA only supported on gfx908+" );
636637 uint32_t getBlgpField = static_cast <uint32_t >(op.getBlgp ());
637638 if (op.getNegateA () || op.getNegateB () || op.getNegateC ()) {
638- if (chipset. minorVersion < 0x40 )
639- return op.emitOpError (" negation unsupported on older than gfx840 " );
639+ if (chipset < kGfx940 )
640+ return op.emitOpError (" negation unsupported on older than gfx940 " );
640641 getBlgpField |=
641642 op.getNegateA () | (op.getNegateB () << 1 ) | (op.getNegateC () << 2 );
642643 }
@@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
741742 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
742743 ConversionPatternRewriter &rewriter) const {
743744 Location loc = op.getLoc ();
744- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x40 )
745+ if (chipset.majorVersion != 9 || chipset < kGfx940 )
745746 return rewriter.notifyMatchFailure (
746747 loc, " Fp8 conversion instructions are not available on target "
747748 " architecture and their emulation is not implemented" );
@@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
785786 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
786787 ConversionPatternRewriter &rewriter) const {
787788 Location loc = op.getLoc ();
788- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x40 )
789+ if (chipset.majorVersion != 9 || chipset < kGfx940 )
789790 return rewriter.notifyMatchFailure (
790791 loc, " Fp8 conversion instructions are not available on target "
791792 " architecture and their emulation is not implemented" );
@@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
822823 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
823824 ConversionPatternRewriter &rewriter) const {
824825 Location loc = op.getLoc ();
825- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x40 )
826+ if (chipset.majorVersion != 9 || chipset < kGfx940 )
826827 return rewriter.notifyMatchFailure (
827828 loc, " Fp8 conversion instructions are not available on target "
828829 " architecture and their emulation is not implemented" );
0 commit comments