@@ -189,6 +189,14 @@ LogicalResult ConvertFloatToTF32Op::verify() {
189189 return success ();
190190}
191191
192+ LogicalResult ConvertF32x2ToF6x2Op::verify () {
193+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy ())) {
194+ return emitError (" Only f6E2M3FN and f6E3M2FN types are supported for "
195+ " ConvertF32x2ToF6x2Op." );
196+ }
197+ return success ();
198+ }
199+
192200LogicalResult ConvertF32x2ToF8x2Op::verify () {
193201 using RndMode = NVVM::FPRoundingMode;
194202 using SatMode = NVVM::SaturationMode;
@@ -200,41 +208,52 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
200208
201209 bool hasRelu = getRelu ();
202210
203- switch (getType ()) {
204- case ConvertFP8Type::E4M3:
205- case ConvertFP8Type::E5M2:
206- if (!isRoundingModeRN)
207- return emitOpError (" Only RN rounding mode is supported for conversions "
208- " from f32x2 to .e4m3x2 or .e5m2x2 types" );
209- if (!isSatFinite)
210- return emitOpError (" Only SATFINITE saturation mode is supported for "
211- " conversions from f32x2 to .e4m3x2 or .e5m2x2 types" );
212- break ;
213- case ConvertFP8Type::UE8M0:
214- if (!(isRoundingModeRZ || isRoundingModeRP))
215- return emitOpError (" Only RZ or RP rounding modes are supported for "
216- " conversions from f32x2 to .ue8m0x2 type" );
217- if (hasRelu)
218- return emitOpError (" relu not supported for conversions to .ue8m0x2 type" );
219- break ;
220- }
221- return success ();
211+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy ())
212+ .Case <mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
213+ [&](mlir::Type) -> LogicalResult {
214+ if (!isRoundingModeRN) {
215+ return emitOpError (
216+ " Only RN rounding mode is supported for conversions from "
217+ " f32x2 to f8E4M3FNx2 or f8E5M2x2 types" );
218+ }
219+ if (!isSatFinite) {
220+ return emitOpError (
221+ " Only SATFINITE saturation mode is supported for conversions "
222+ " from f32x2 to f8E4M3FNx2 or f8E5M2x2 types" );
223+ }
224+ return success ();
225+ })
226+ .Case <mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
227+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
228+ return emitOpError (" Only RZ or RP rounding modes are supported for "
229+ " conversions from f32x2 to f8E8M0FNUx2 type" );
230+ }
231+ if (hasRelu) {
232+ return emitOpError (
233+ " relu not supported for conversions to f8E8M0FNUx2 type" );
234+ }
235+ return success ();
236+ })
237+ .Default ([this ](mlir::Type) {
238+ return emitOpError (" Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are "
239+ " supported for conversions from f32x2 to f8x2" );
240+ });
222241}
223242
224243LogicalResult ConvertF16x2ToF8x2Op::verify () {
225- if (getType () == ConvertFP8Type::UE8M0)
226- return emitOpError (" Only .e4m3 or .e5m2 types are supported for "
244+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>( getDstTy ())) {
245+ return emitOpError (" Only f8E4M3FN or f8E5M2 types are supported for "
227246 " conversions from f16x2 to f8x2." );
228-
247+ }
229248 return success ();
230249}
231250
232251LogicalResult ConvertBF16x2ToF8x2Op::verify () {
233252 using RndMode = NVVM::FPRoundingMode;
234253
235- if (getType () != ConvertFP8Type::UE8M0 )
236- return emitOpError (
237- " Only .ue8m0 type is supported for conversions from bf16x2 to f8x2." );
254+ if (!llvm::isa<mlir::Float8E8M0FNUType>( getDstTy ()) )
255+ return emitOpError (" Only f8E8M0FNU type is supported for conversions from "
256+ " bf16x2 to f8x2." );
238257
239258 auto rnd = getRnd ();
240259 if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1714,15 +1733,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
17141733 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
17151734 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
17161735
1717- llvm::Intrinsic::ID
1718- ConvertF32x2ToF6x2Op::getIntrinsicID (NVVM::ConvertFP6Type type, bool hasRelu) {
1719- switch (type) {
1720- case NVVM::ConvertFP6Type::E2M3:
1721- return GET_F32x2_TO_F6x2_ID (e2m3x2, hasRelu);
1722- case NVVM::ConvertFP6Type::E3M2:
1723- return GET_F32x2_TO_F6x2_ID (e3m2x2, hasRelu);
1724- }
1725- llvm_unreachable (" Invalid conversion in ConvertF32x2ToF6x2Op" );
1736+ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID (mlir::Type dstTy,
1737+ bool hasRelu) {
1738+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
1739+ .Case <mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
1740+ return GET_F32x2_TO_F6x2_ID (e2m3x2, hasRelu);
1741+ })
1742+ .Case <mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
1743+ return GET_F32x2_TO_F6x2_ID (e3m2x2, hasRelu);
1744+ })
1745+ .Default ([](mlir::Type) {
1746+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF6x2Op" );
1747+ return llvm::Intrinsic::not_intrinsic;
1748+ });
17261749}
17271750
17281751#define GET_F32x2_TO_F8X2_US_ID (rnd, has_satf ) \
@@ -1734,41 +1757,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
17341757 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
17351758
17361759llvm::Intrinsic::ID
1737- ConvertF32x2ToF8x2Op::getIntrinsicID (NVVM::ConvertFP8Type type,
1738- NVVM::FPRoundingMode rnd,
1760+ ConvertF32x2ToF8x2Op::getIntrinsicID (mlir::Type dstTy, NVVM::FPRoundingMode rnd,
17391761 NVVM::SaturationMode sat, bool hasRelu) {
17401762 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
17411763 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
17421764 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
17431765
1744- switch (type) {
1745- case NVVM::ConvertFP8Type::E4M3:
1746- return GET_F32x2_TO_F8X2_S_ID (e4m3x2, hasRelu);
1747- case NVVM::ConvertFP8Type::E5M2:
1748- return GET_F32x2_TO_F8X2_S_ID (e5m2x2, hasRelu);
1749- case NVVM::ConvertFP8Type::UE8M0:
1750- if (hasRoundingModeRZ)
1751- return GET_F32x2_TO_F8X2_US_ID (rz, hasSatFinite);
1752- else if (hasRoundingModeRP)
1753- return GET_F32x2_TO_F8X2_US_ID (rp, hasSatFinite);
1754- }
1755- llvm_unreachable (" Invalid conversion in CvtFloatToF8x2Op" );
1766+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
1767+ .Case <mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
1768+ return GET_F32x2_TO_F8X2_S_ID (e4m3x2, hasRelu);
1769+ })
1770+ .Case <mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
1771+ return GET_F32x2_TO_F8X2_S_ID (e5m2x2, hasRelu);
1772+ })
1773+ .Case <mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
1774+ if (hasRoundingModeRZ)
1775+ return GET_F32x2_TO_F8X2_US_ID (rz, hasSatFinite);
1776+ else if (hasRoundingModeRP)
1777+ return GET_F32x2_TO_F8X2_US_ID (rp, hasSatFinite);
1778+
1779+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF8x2Op" );
1780+ })
1781+ .Default ([](mlir::Type) {
1782+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF8x2Op" );
1783+ return llvm::Intrinsic::not_intrinsic;
1784+ });
17561785}
17571786
17581787#define GET_F16x2_TO_F8X2_ID (type, has_relu ) \
17591788 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
17601789 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
17611790
1762- llvm::Intrinsic::ID
1763- ConvertF16x2ToF8x2Op::getIntrinsicID (NVVM::ConvertFP8Type type, bool hasRelu) {
1764- switch (type) {
1765- case NVVM::ConvertFP8Type::E4M3:
1766- return GET_F16x2_TO_F8X2_ID (e4m3x2, hasRelu);
1767- case NVVM::ConvertFP8Type::E5M2:
1768- return GET_F16x2_TO_F8X2_ID (e5m2x2, hasRelu);
1769- default :
1770- llvm_unreachable (" Invalid ConvertFP8Type for CvtF16x2ToF8x2Op" );
1771- }
1791+ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID (mlir::Type dstTy,
1792+ bool hasRelu) {
1793+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
1794+ .Case <mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
1795+ return GET_F16x2_TO_F8X2_ID (e4m3x2, hasRelu);
1796+ })
1797+ .Case <mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
1798+ return GET_F16x2_TO_F8X2_ID (e5m2x2, hasRelu);
1799+ })
1800+ .Default ([](mlir::Type) {
1801+ llvm_unreachable (" Invalid conversion in ConvertF16x2ToF8x2Op" );
1802+ return llvm::Intrinsic::not_intrinsic;
1803+ });
17721804}
17731805
17741806#define GET_BF16X2_TO_F8X2_ID (rnd, has_satf ) \
0 commit comments