@@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() {
216
216
return success ();
217
217
}
218
218
219
+ LogicalResult ConvertF32x2ToF6x2Op::verify () {
220
+ mlir::MLIRContext *ctx = getContext ();
221
+
222
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy ())) {
223
+ return emitOpError (" Only " )
224
+ << mlir::Float6E2M3FNType::get (ctx) << " and "
225
+ << mlir::Float6E3M2FNType::get (ctx)
226
+ << " types are supported for conversions from f32x2 to f6x2." ;
227
+ }
228
+ return success ();
229
+ }
230
+
219
231
LogicalResult ConvertF32x2ToF8x2Op::verify () {
220
232
using RndMode = NVVM::FPRoundingMode;
221
233
using SatMode = NVVM::SaturationMode;
@@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
227
239
228
240
bool hasRelu = getRelu ();
229
241
230
- switch (getType ()) {
231
- case ConvertFP8Type::E4M3:
232
- case ConvertFP8Type::E5M2:
233
- if (!isRoundingModeRN)
234
- return emitOpError (" Only RN rounding mode is supported for conversions "
235
- " from f32x2 to .e4m3x2 or .e5m2x2 types" );
236
- if (!isSatFinite)
237
- return emitOpError (" Only SATFINITE saturation mode is supported for "
238
- " conversions from f32x2 to .e4m3x2 or .e5m2x2 types" );
239
- break ;
240
- case ConvertFP8Type::UE8M0:
241
- if (!(isRoundingModeRZ || isRoundingModeRP))
242
- return emitOpError (" Only RZ or RP rounding modes are supported for "
243
- " conversions from f32x2 to .ue8m0x2 type" );
244
- if (hasRelu)
245
- return emitOpError (" relu not supported for conversions to .ue8m0x2 type" );
246
- break ;
247
- }
248
- return success ();
242
+ mlir::MLIRContext *ctx = getContext ();
243
+
244
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy ())
245
+ .Case <mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
246
+ [&](mlir::Type) -> LogicalResult {
247
+ if (!isRoundingModeRN) {
248
+ return emitOpError (" Only RN rounding mode is supported for "
249
+ " conversions from f32x2 to " )
250
+ << mlir::Float8E4M3FNType::get (ctx) << " and "
251
+ << mlir::Float8E5M2Type::get (ctx) << " types" ;
252
+ }
253
+ if (!isSatFinite) {
254
+ return emitOpError (" Only SATFINITE saturation mode is supported "
255
+ " for conversions "
256
+ " from f32x2 to " )
257
+ << mlir::Float8E4M3FNType::get (ctx) << " and "
258
+ << mlir::Float8E5M2Type::get (ctx) << " types" ;
259
+ }
260
+ return success ();
261
+ })
262
+ .Case <mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
263
+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
264
+ return emitOpError (" Only RZ and RP rounding modes are supported for "
265
+ " conversions from f32x2 to " )
266
+ << mlir::Float8E8M0FNUType::get (ctx) << " type" ;
267
+ }
268
+ if (hasRelu) {
269
+ return emitOpError (" relu not supported for conversions to " )
270
+ << mlir::Float8E8M0FNUType::get (ctx) << " type" ;
271
+ }
272
+ return success ();
273
+ })
274
+ .Default ([&](mlir::Type) {
275
+ return emitOpError (" Only " )
276
+ << mlir::Float8E4M3FNType::get (ctx) << " , "
277
+ << mlir::Float8E5M2Type::get (ctx) << " , and "
278
+ << mlir::Float8E8M0FNUType::get (ctx)
279
+ << " types are "
280
+ " supported for conversions from f32x2 to f8x2" ;
281
+ });
249
282
}
250
283
251
284
LogicalResult ConvertF16x2ToF8x2Op::verify () {
252
- if (getType () == ConvertFP8Type::UE8M0)
253
- return emitOpError (" Only .e4m3 or .e5m2 types are supported for "
254
- " conversions from f16x2 to f8x2." );
285
+ mlir::MLIRContext *ctx = getContext ();
255
286
287
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy ())) {
288
+ return emitOpError (" Only " )
289
+ << mlir::Float8E4M3FNType::get (ctx) << " and "
290
+ << mlir::Float8E5M2Type::get (ctx)
291
+ << " types are supported for conversions from f16x2 to f8x2." ;
292
+ }
256
293
return success ();
257
294
}
258
295
259
296
LogicalResult ConvertBF16x2ToF8x2Op::verify () {
260
297
using RndMode = NVVM::FPRoundingMode;
261
298
262
- if (getType () != ConvertFP8Type::UE8M0)
263
- return emitOpError (
264
- " Only .ue8m0 type is supported for conversions from bf16x2 to f8x2." );
299
+ if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy ()))
300
+ return emitOpError (" Only " ) << mlir::Float8E8M0FNUType::get (getContext ())
301
+ << " type is supported for conversions from "
302
+ " bf16x2 to f8x2." ;
265
303
266
304
auto rnd = getRnd ();
267
305
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1980
2018
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1981
2019
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1982
2020
1983
- llvm::Intrinsic::ID
1984
- ConvertF32x2ToF6x2Op::getIntrinsicID (NVVM::ConvertFP6Type type, bool hasRelu) {
1985
- switch (type) {
1986
- case NVVM::ConvertFP6Type::E2M3:
1987
- return GET_F32x2_TO_F6x2_ID (e2m3x2, hasRelu);
1988
- case NVVM::ConvertFP6Type::E3M2:
1989
- return GET_F32x2_TO_F6x2_ID (e3m2x2, hasRelu);
1990
- }
1991
- llvm_unreachable (" Invalid conversion in ConvertF32x2ToF6x2Op" );
2021
+ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID (mlir::Type dstTy,
2022
+ bool hasRelu) {
2023
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2024
+ .Case <mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2025
+ return GET_F32x2_TO_F6x2_ID (e2m3x2, hasRelu);
2026
+ })
2027
+ .Case <mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2028
+ return GET_F32x2_TO_F6x2_ID (e3m2x2, hasRelu);
2029
+ })
2030
+ .Default ([](mlir::Type) {
2031
+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF6x2Op" );
2032
+ return llvm::Intrinsic::not_intrinsic;
2033
+ });
1992
2034
}
1993
2035
1994
2036
#define GET_F32x2_TO_F8X2_US_ID (rnd, has_satf ) \
@@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
2000
2042
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
2001
2043
2002
2044
llvm::Intrinsic::ID
2003
- ConvertF32x2ToF8x2Op::getIntrinsicID (NVVM::ConvertFP8Type type,
2004
- NVVM::FPRoundingMode rnd,
2045
+ ConvertF32x2ToF8x2Op::getIntrinsicID (mlir::Type dstTy, NVVM::FPRoundingMode rnd,
2005
2046
NVVM::SaturationMode sat, bool hasRelu) {
2006
2047
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2007
2048
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
2008
2049
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
2009
2050
2010
- switch (type) {
2011
- case NVVM::ConvertFP8Type::E4M3:
2012
- return GET_F32x2_TO_F8X2_S_ID (e4m3x2, hasRelu);
2013
- case NVVM::ConvertFP8Type::E5M2:
2014
- return GET_F32x2_TO_F8X2_S_ID (e5m2x2, hasRelu);
2015
- case NVVM::ConvertFP8Type::UE8M0:
2016
- if (hasRoundingModeRZ)
2017
- return GET_F32x2_TO_F8X2_US_ID (rz, hasSatFinite);
2018
- else if (hasRoundingModeRP)
2019
- return GET_F32x2_TO_F8X2_US_ID (rp, hasSatFinite);
2020
- }
2021
- llvm_unreachable (" Invalid conversion in CvtFloatToF8x2Op" );
2051
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2052
+ .Case <mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2053
+ return GET_F32x2_TO_F8X2_S_ID (e4m3x2, hasRelu);
2054
+ })
2055
+ .Case <mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2056
+ return GET_F32x2_TO_F8X2_S_ID (e5m2x2, hasRelu);
2057
+ })
2058
+ .Case <mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2059
+ if (hasRoundingModeRZ)
2060
+ return GET_F32x2_TO_F8X2_US_ID (rz, hasSatFinite);
2061
+ else if (hasRoundingModeRP)
2062
+ return GET_F32x2_TO_F8X2_US_ID (rp, hasSatFinite);
2063
+
2064
+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF8x2Op" );
2065
+ })
2066
+ .Default ([](mlir::Type) {
2067
+ llvm_unreachable (" Invalid conversion in ConvertF32x2ToF8x2Op" );
2068
+ return llvm::Intrinsic::not_intrinsic;
2069
+ });
2022
2070
}
2023
2071
2024
2072
#define GET_F16x2_TO_F8X2_ID (type, has_relu ) \
2025
2073
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
2026
2074
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
2027
2075
2028
- llvm::Intrinsic::ID
2029
- ConvertF16x2ToF8x2Op::getIntrinsicID (NVVM::ConvertFP8Type type, bool hasRelu) {
2030
- switch (type) {
2031
- case NVVM::ConvertFP8Type::E4M3:
2032
- return GET_F16x2_TO_F8X2_ID (e4m3x2, hasRelu);
2033
- case NVVM::ConvertFP8Type::E5M2:
2034
- return GET_F16x2_TO_F8X2_ID (e5m2x2, hasRelu);
2035
- default :
2036
- llvm_unreachable (" Invalid ConvertFP8Type for CvtF16x2ToF8x2Op" );
2037
- }
2076
+ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID (mlir::Type dstTy,
2077
+ bool hasRelu) {
2078
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2079
+ .Case <mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2080
+ return GET_F16x2_TO_F8X2_ID (e4m3x2, hasRelu);
2081
+ })
2082
+ .Case <mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2083
+ return GET_F16x2_TO_F8X2_ID (e5m2x2, hasRelu);
2084
+ })
2085
+ .Default ([](mlir::Type) {
2086
+ llvm_unreachable (" Invalid conversion in ConvertF16x2ToF8x2Op" );
2087
+ return llvm::Intrinsic::not_intrinsic;
2088
+ });
2038
2089
}
2039
2090
2040
2091
#define GET_BF16X2_TO_F8X2_ID (rnd, has_satf ) \
0 commit comments