@@ -601,6 +601,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
601601 }
602602}
603603
604+ // / Return true if `type` is the E5M2 variant of an 8-bit float that is
605+ // / supported by the `_bf8` instructions on the given `chipset`.
606+ static bool typeIsExpectedBf8ForChipset (Chipset chipset, Type type) {
607+ return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
608+ (hasOcpFp8 (chipset) && isa<Float8E5M2Type>(type));
609+ }
610+
611+ // / Return true if `type` is the E4M3FN variant of an 8-bit float that is
612+ // / supported by the `_fp8` instructions on the given `chipset`.
613+ static bool typeIsExpectedFp8ForChipset (Chipset chipset, Type type) {
614+ return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
615+ (hasOcpFp8 (chipset) && isa<Float8E4M3FNType>(type));
616+ }
617+
604618// / Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
605619// / if one exists. This includes checking to ensure the intrinsic is supported
606620// / on the architecture you are compiling for.
@@ -697,40 +711,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
697711 return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
698712 }
699713
700- if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32 () &&
701- chipset >= kGfx942 ) {
714+ if (destElem.isF32 () && typeIsExpectedBf8ForChipset (chipset, sourceElem)) {
702715 // Known to be correct because there are no scalar f8 instructions and
703716 // because a length mismatch will have been caught by the verifier.
704717 Type sourceBElem =
705718 cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
706719 if (m == 16 && n == 16 && k == 32 && b == 1 ) {
707- if (isa<Float8E5M2FNUZType>( sourceBElem))
720+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
708721 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
709- if (isa<Float8E4M3FNUZType>( sourceBElem))
722+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
710723 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
711724 }
712725 if (m == 32 && n == 32 && k == 16 && b == 1 ) {
713- if (isa<Float8E5M2FNUZType>( sourceBElem))
726+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
714727 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
715- if (isa<Float8E4M3FNUZType>( sourceBElem))
728+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
716729 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
717730 }
718731 }
719732
720- if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32 () &&
721- chipset >= kGfx942 ) {
733+ if (destElem.isF32 () && typeIsExpectedFp8ForChipset (chipset, sourceElem)) {
722734 Type sourceBElem =
723735 cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
724736 if (m == 16 && n == 16 && k == 32 && b == 1 ) {
725- if (isa<Float8E5M2FNUZType>( sourceBElem))
737+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
726738 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
727- if (isa<Float8E4M3FNUZType>( sourceBElem))
739+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
728740 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
729741 }
730742 if (m == 32 && n == 32 && k == 16 && b == 1 ) {
731- if (isa<Float8E5M2FNUZType>( sourceBElem))
743+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
732744 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
733- if (isa<Float8E4M3FNUZType>( sourceBElem))
745+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
734746 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
735747 }
736748 }
@@ -936,7 +948,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
936948 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
937949 ConversionPatternRewriter &rewriter) const {
938950 Location loc = op.getLoc ();
939- if (chipset. majorVersion != 9 || chipset < kGfx942 )
951+ if (!( chipset == kGfx942 || hasOcpFp8 ( chipset)) )
940952 return rewriter.notifyMatchFailure (
941953 loc, " Fp8 conversion instructions are not available on target "
942954 " architecture and their emulation is not implemented" );
@@ -966,10 +978,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
966978 }
967979 Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
968980 Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
969- if (isa<Float8E5M2FNUZType>( sourceElemType)) {
981+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
970982 rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
971983 wordSel);
972- } else if (isa<Float8E4M3FNUZType>( sourceElemType)) {
984+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
973985 rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
974986 wordSel);
975987 }
@@ -980,7 +992,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
980992 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
981993 ConversionPatternRewriter &rewriter) const {
982994 Location loc = op.getLoc ();
983- if (chipset. majorVersion != 9 || chipset < kGfx942 )
995+ if (!( chipset == kGfx942 || hasOcpFp8 ( chipset)) )
984996 return rewriter.notifyMatchFailure (
985997 loc, " Fp8 conversion instructions are not available on target "
986998 " architecture and their emulation is not implemented" );
@@ -1001,10 +1013,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
10011013 Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
10021014
10031015 Value result;
1004- if (isa<Float8E5M2FNUZType>( resultElemType))
1016+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
10051017 result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
10061018 existing, wordSel);
1007- else if (isa<Float8E4M3FNUZType>( resultElemType))
1019+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
10081020 result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
10091021 existing, wordSel);
10101022
@@ -1017,7 +1029,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
10171029 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
10181030 ConversionPatternRewriter &rewriter) const {
10191031 Location loc = op.getLoc ();
1020- if (chipset. majorVersion != 9 || chipset < kGfx942 )
1032+ if (!( chipset == kGfx942 || hasOcpFp8 ( chipset)) )
10211033 return rewriter.notifyMatchFailure (
10221034 loc, " Fp8 conversion instructions are not available on target "
10231035 " architecture and their emulation is not implemented" );
@@ -1036,10 +1048,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
10361048 Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
10371049
10381050 Value result;
1039- if (isa<Float8E5M2FNUZType>( resultElemType))
1051+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
10401052 result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
10411053 existing, byteSel);
1042- else if (isa<Float8E4M3FNUZType>( resultElemType))
1054+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
10431055 result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
10441056 existing, byteSel);
10451057
0 commit comments