@@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() {
216216 return success ();
217217}
218218
219+ LogicalResult ConvertF32x2ToF6x2Op::verify () {
220+ mlir::MLIRContext *ctx = getContext ();
221+
222+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy ())) {
223+ return emitOpError (" Only " )
224+ << mlir::Float6E2M3FNType::get (ctx) << " and "
225+ << mlir::Float6E3M2FNType::get (ctx)
226+ << " types are supported for conversions from f32x2 to f6x2." ;
227+ }
228+ return success ();
229+ }
230+
219231LogicalResult ConvertF32x2ToF8x2Op::verify () {
220232 using RndMode = NVVM::FPRoundingMode;
221233 using SatMode = NVVM::SaturationMode;
@@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
227239
228240 bool hasRelu = getRelu ();
229241
230- switch (getType ()) {
231- case ConvertFP8Type::E4M3:
232- case ConvertFP8Type::E5M2:
233- if (!isRoundingModeRN)
234- return emitOpError (" Only RN rounding mode is supported for conversions "
235- " from f32x2 to .e4m3x2 or .e5m2x2 types" );
236- if (!isSatFinite)
237- return emitOpError (" Only SATFINITE saturation mode is supported for "
238- " conversions from f32x2 to .e4m3x2 or .e5m2x2 types" );
239- break ;
240- case ConvertFP8Type::UE8M0:
241- if (!(isRoundingModeRZ || isRoundingModeRP))
242- return emitOpError (" Only RZ or RP rounding modes are supported for "
243- " conversions from f32x2 to .ue8m0x2 type" );
244- if (hasRelu)
245- return emitOpError (" relu not supported for conversions to .ue8m0x2 type" );
246- break ;
247- }
248- return success ();
242+ mlir::MLIRContext *ctx = getContext ();
243+
244+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy ())
245+ .Case <mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
246+ [&](mlir::Type) -> LogicalResult {
247+ if (!isRoundingModeRN) {
248+ return emitOpError (" Only RN rounding mode is supported for "
249+ " conversions from f32x2 to " )
250+ << mlir::Float8E4M3FNType::get (ctx) << " and "
251+ << mlir::Float8E5M2Type::get (ctx) << " types" ;
252+ }
253+ if (!isSatFinite) {
254+ return emitOpError (" Only SATFINITE saturation mode is supported "
255+ " for conversions "
256+ " from f32x2 to " )
257+ << mlir::Float8E4M3FNType::get (ctx) << " and "
258+ << mlir::Float8E5M2Type::get (ctx) << " types" ;
259+ }
260+ return success ();
261+ })
262+ .Case <mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
263+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
264+ return emitOpError (" Only RZ and RP rounding modes are supported for "
265+ " conversions from f32x2 to " )
266+ << mlir::Float8E8M0FNUType::get (ctx) << " type" ;
267+ }
268+ if (hasRelu) {
269+ return emitOpError (" relu not supported for conversions to " )
270+ << mlir::Float8E8M0FNUType::get (ctx) << " type" ;
271+ }
272+ return success ();
273+ })
274+ .Default ([&](mlir::Type) {
275+ return emitOpError (" Only " )
276+ << mlir::Float8E4M3FNType::get (ctx) << " , "
277+ << mlir::Float8E5M2Type::get (ctx) << " , and "
278+ << mlir::Float8E8M0FNUType::get (ctx)
279+ << " types are "
280+ " supported for conversions from f32x2 to f8x2" ;
281+ });
249282}
250283
251284LogicalResult ConvertF16x2ToF8x2Op::verify () {
252- if (getType () == ConvertFP8Type::UE8M0)
253- return emitOpError (" Only .e4m3 or .e5m2 types are supported for "
254- " conversions from f16x2 to f8x2." );
285+ mlir::MLIRContext *ctx = getContext ();
255286
287+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy ())) {
288+ return emitOpError (" Only " )
289+ << mlir::Float8E4M3FNType::get (ctx) << " and "
290+ << mlir::Float8E5M2Type::get (ctx)
291+ << " types are supported for conversions from f16x2 to f8x2." ;
292+ }
256293 return success ();
257294}
258295
259296LogicalResult ConvertBF16x2ToF8x2Op::verify () {
260297 using RndMode = NVVM::FPRoundingMode;
261298
262- if (getType () != ConvertFP8Type::UE8M0)
263- return emitOpError (
264- " Only .ue8m0 type is supported for conversions from bf16x2 to f8x2." );
299+ if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy ()))
300+ return emitOpError (" Only " ) << mlir::Float8E8M0FNUType::get (getContext ())
301+ << " type is supported for conversions from "
302+ " bf16x2 to f8x2." ;
265303
266304 auto rnd = getRnd ();
267305 if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
19802018 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
19812019 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
19822020
1983- llvm::Intrinsic::ID
1984- ConvertF32x2ToF6x2Op::getIntrinsicID (NVVM::ConvertFP6Type type, bool hasRelu) {
1985- switch (type) {
1986- case NVVM::ConvertFP6Type::E2M3:
1987- return GET_F32x2_TO_F6x2_ID (e2m3x2, hasRelu);
1988- case NVVM::ConvertFP6Type::E3M2:
1989- return GET_F32x2_TO_F6x2_ID (e3m2x2, hasRelu);
1990- }
1991- llvm_unreachable (" Invalid conversion in ConvertF32x2ToF6x2Op" );
2021+ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID (mlir::Type dstTy,
2022+ bool hasRelu) {
2023+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2024+ .Case <mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2025+ return GET_F32x2_TO_F6x2_ID (e2m3x2, hasRelu);
2026+ })
2027+ .Case <mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2028+ return GET_F32x2_TO_F6x2_ID (e3m2x2, hasRelu);
2029+ })
2030+ .Default ([](mlir::Type) {
2031+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF6x2Op" );
2032+ return llvm::Intrinsic::not_intrinsic;
2033+ });
19922034}
19932035
19942036#define GET_F32x2_TO_F8X2_US_ID (rnd, has_satf ) \
@@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
20002042 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
20012043
20022044llvm::Intrinsic::ID
2003- ConvertF32x2ToF8x2Op::getIntrinsicID (NVVM::ConvertFP8Type type,
2004- NVVM::FPRoundingMode rnd,
2045+ ConvertF32x2ToF8x2Op::getIntrinsicID (mlir::Type dstTy, NVVM::FPRoundingMode rnd,
20052046 NVVM::SaturationMode sat, bool hasRelu) {
20062047 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
20072048 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
20082049 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
20092050
2010- switch (type) {
2011- case NVVM::ConvertFP8Type::E4M3:
2012- return GET_F32x2_TO_F8X2_S_ID (e4m3x2, hasRelu);
2013- case NVVM::ConvertFP8Type::E5M2:
2014- return GET_F32x2_TO_F8X2_S_ID (e5m2x2, hasRelu);
2015- case NVVM::ConvertFP8Type::UE8M0:
2016- if (hasRoundingModeRZ)
2017- return GET_F32x2_TO_F8X2_US_ID (rz, hasSatFinite);
2018- else if (hasRoundingModeRP)
2019- return GET_F32x2_TO_F8X2_US_ID (rp, hasSatFinite);
2020- }
2021- llvm_unreachable (" Invalid conversion in CvtFloatToF8x2Op" );
2051+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2052+ .Case <mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2053+ return GET_F32x2_TO_F8X2_S_ID (e4m3x2, hasRelu);
2054+ })
2055+ .Case <mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2056+ return GET_F32x2_TO_F8X2_S_ID (e5m2x2, hasRelu);
2057+ })
2058+ .Case <mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2059+ if (hasRoundingModeRZ)
2060+ return GET_F32x2_TO_F8X2_US_ID (rz, hasSatFinite);
2061+ else if (hasRoundingModeRP)
2062+ return GET_F32x2_TO_F8X2_US_ID (rp, hasSatFinite);
2063+
2064+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF8x2Op" );
2065+ })
2066+ .Default ([](mlir::Type) {
2067+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF8x2Op" );
2068+ return llvm::Intrinsic::not_intrinsic;
2069+ });
20222070}
20232071
20242072#define GET_F16x2_TO_F8X2_ID (type, has_relu ) \
20252073 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
20262074 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
20272075
2028- llvm::Intrinsic::ID
2029- ConvertF16x2ToF8x2Op::getIntrinsicID (NVVM::ConvertFP8Type type, bool hasRelu) {
2030- switch (type) {
2031- case NVVM::ConvertFP8Type::E4M3:
2032- return GET_F16x2_TO_F8X2_ID (e4m3x2, hasRelu);
2033- case NVVM::ConvertFP8Type::E5M2:
2034- return GET_F16x2_TO_F8X2_ID (e5m2x2, hasRelu);
2035- default :
2036- llvm_unreachable (" Invalid ConvertFP8Type for CvtF16x2ToF8x2Op" );
2037- }
2076+ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID (mlir::Type dstTy,
2077+ bool hasRelu) {
2078+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2079+ .Case <mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2080+ return GET_F16x2_TO_F8X2_ID (e4m3x2, hasRelu);
2081+ })
2082+ .Case <mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2083+ return GET_F16x2_TO_F8X2_ID (e5m2x2, hasRelu);
2084+ })
2085+ .Default ([](mlir::Type) {
2086+ llvm_unreachable (" Invalid conversion in ConvertF16x2ToF8x2Op" );
2087+ return llvm::Intrinsic::not_intrinsic;
2088+ });
20382089}
20392090
20402091#define GET_BF16X2_TO_F8X2_ID (rnd, has_satf ) \
0 commit comments