diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 27d54e7abeda9..6540273b216e3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } -def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { +def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ This Op converts each of the given float inputs to the specified fp6 type. @@ -1096,6 +1096,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins CVTFP6TypeAttr:$type, @@ -1110,7 +1111,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { }]; string llvmBuilder = [{ - auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu); + auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1120,6 +1121,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { }]; } +def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">; +def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">; +def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">; + +def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind", + [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def CVTFP8TypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> { + let summary = "Convert a pair of float inputs to f8x2"; + let description = [{ + This Op converts each of the given float 
inputs to the specified fp8 type. + The result `dst` is represented as an i16 type or as a vector + of two i8 types. + If `dst` is returned as an i16 type, the converted values are packed such + that the value converted from `a` is stored in the upper 8 bits of `dst` + and the value converted from `b` is stored in the lower 8 bits of `dst`. + If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector. + The `rnd` and `sat` attributes specify the rounding and saturation modes respectively. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let hasVerifier = 1; + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins + CVTFP8TypeAttr:$type, + F32:$a, + F32:$b, + DefaultValuedAttr:$rnd, + DefaultValuedAttr:$sat, + DefaultValuedAttr:$relu); + let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; + + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to, + NVVM::FPRoundingMode rnd, + NVVM::SaturationMode sat, + bool hasRelu); + }]; + + string llvmBuilder = [{ + auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + +def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> { + let summary = "Convert an f16x2 input to f8x2"; + let description = [{ + This Op converts the given f16 inputs in an f16x2 vector to the specified + f8 type. + The result `dst` is represented as an i16 type or as a vector + of two i8 types. 
+ If `dst` is returned as an i16 type, the converted values from `a` + are packed such that the value converted from the first element of `a` + is stored in the upper 8 bits of `dst` and the value converted from the + second element of `a` is stored in the lower 8 bits of `dst`. + If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let hasVerifier = 1; + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins + CVTFP8TypeAttr:$type, + VectorOfLengthAndType<[2], [F16]>:$a, + DefaultValuedAttr:$relu); + let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; + + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to, + bool hasRelu); + }]; + + string llvmBuilder = [{ + auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + +def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> { + let summary = "Convert a bf16x2 input to f8x2"; + let description = [{ + This Op converts the given bf16 inputs in a bf16x2 vector to the specified + f8 type. + The result `dst` is represented as an i16 type or as a vector + of two i8 types. 
+ If `dst` is returned as an i16 type, the converted values from `a` + are packed such that the value converted from the first element of `a` + is stored in the upper 8 bits of `dst` and the value converted from the + second element of `a` is stored in the lower 8 bits of `dst`. + If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector. + The `rnd` and `sat` attributes specify the rounding and saturation modes + respectively. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let hasVerifier = 1; + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins + CVTFP8TypeAttr:$type, + VectorOfLengthAndType<[2], [BF16]>:$a, + DefaultValuedAttr:$rnd, + DefaultValuedAttr:$sat); + let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; + + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd, + NVVM::SaturationMode sat); + }]; + + string llvmBuilder = [{ + auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + //===----------------------------------------------------------------------===// // NVVM MMA Ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 18453aa7f6ea9..3c3731a63e268 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -133,6 +133,61 @@ LogicalResult CvtFloatToTF32Op::verify() { return success(); } +LogicalResult 
CvtF32x2ToF8x2Op::verify() { + using RndMode = NVVM::FPRoundingMode; + using SatMode = NVVM::SaturationMode; + + bool isRoundingModeRN = getRnd() == RndMode::RN; + bool isRoundingModeRZ = getRnd() == RndMode::RZ; + bool isRoundingModeRP = getRnd() == RndMode::RP; + bool isSatFinite = getSat() == SatMode::SATFINITE; + + bool hasRelu = getRelu(); + + switch (getType()) { + case CVTFP8Type::E4M3: + case CVTFP8Type::E5M2: + if (!isRoundingModeRN) + return emitOpError("Only RN rounding mode is supported for conversions " + "from f32x2 to .e4m3x2 or .e5m2x2 types"); + if (!isSatFinite) + return emitOpError("Only SATFINITE saturation mode is supported for " + "conversions from f32x2 to .e4m3x2 or .e5m2x2 types"); + break; + case CVTFP8Type::UE8M0: + if (!(isRoundingModeRZ || isRoundingModeRP)) + return emitOpError("Only RZ or RP rounding modes are supported for " + "conversions from f32x2 to .ue8m0x2 type"); + if (hasRelu) + return emitOpError("relu not supported for conversions to .ue8m0x2 type"); + break; + } + return success(); +} + +LogicalResult CvtF16x2ToF8x2Op::verify() { + if (getType() == CVTFP8Type::UE8M0) + return emitOpError("Only .e4m3 or .e5m2 types are supported for " + "conversions from f16x2 to f8x2."); + + return success(); +} + +LogicalResult CvtBF16x2ToF8x2Op::verify() { + using RndMode = NVVM::FPRoundingMode; + + if (getType() != CVTFP8Type::UE8M0) + return emitOpError( + "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2."); + + auto rnd = getRnd(); + if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) + return emitOpError("Only RZ and RP rounding modes are supported for " + "conversions from bf16x2 to f8x2."); + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -1290,17 +1345,81 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } -#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \ +#define 
GET_F32x2_TO_F6x2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite -llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type, - bool hasRelu) { +llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type, + bool hasRelu) { switch (type) { case NVVM::CVTFP6Type::E2M3: - return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu); + return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); case NVVM::CVTFP6Type::E3M2: - return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu); + return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); + } +} + +#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \ + has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \ + : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd + +#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \ + has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \ + : llvm::Intrinsic::nvvm_ff_to_##type##_rn + +llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type, + NVVM::FPRoundingMode rnd, + NVVM::SaturationMode sat, + bool hasRelu) { + bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); + bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ); + bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); + + switch (type) { + case NVVM::CVTFP8Type::E4M3: + return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); + case NVVM::CVTFP8Type::E5M2: + return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); + case NVVM::CVTFP8Type::UE8M0: + if (hasRoundingModeRZ) + return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); + else if (hasRoundingModeRP) + return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); + } + llvm_unreachable("Invalid conversion in CvtF32x2ToF8x2Op"); +} + +#define GET_F16x2_TO_F8X2_ID(type, has_relu) \ + has_relu ? 
llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \ + : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn + +llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type, + bool hasRelu) { + switch (type) { + case NVVM::CVTFP8Type::E4M3: + return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); + case NVVM::CVTFP8Type::E5M2: + return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); + default: + llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op"); + } +} + +#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \ + has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \ + : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd + +llvm::Intrinsic::ID +CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, + NVVM::SaturationMode sat) { + bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); + switch (rnd) { + case NVVM::FPRoundingMode::RZ: + return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite); + case NVVM::FPRoundingMode::RP: + return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite); + default: + llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op"); } } diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir index 2237e6faad52d..8ccc656e57e1c 100644 --- a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir @@ -1,22 +1,21 @@ // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s -// CHECK-LABEL: @convert_float_to_fp6x2_packed -llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) { +// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed +llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.cvt.to.f6x2 %srcA, %srcB : i16 + %res1 = nvvm.cvt.f32x2.to.f6x2 %srcA, %srcB : i16 //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.cvt.to.f6x2 %srcA, %srcB : i16 + %res2 = nvvm.cvt.f32x2.to.f6x2 %srcA, 
%srcB : i16 llvm.return } -// CHECK-LABEL: @convert_float_to_fp6x2_vector -llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) { +// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector +llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) { //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8> - %res1 = nvvm.cvt.to.f6x2 %srcA, %srcB : vector<2xi8> + %res1 = nvvm.cvt.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res2 = nvvm.cvt.to.f6x2 %srcA, %srcB : vector<2xi8> + %res2 = nvvm.cvt.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> llvm.return } - diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir new file mode 100644 index 0000000000000..8ea0bbabe4d0a --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3 +llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2 +llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // CHECK: %{{.*}} = call i16 
@llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0 +llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f8x2_vector_return +llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> + %res1 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> + %res2 = nvvm.cvt.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_f16x2_to_f8x2_e4m3 +llvm.func @convert_f16x2_to_f8x2_e4m3(%src : vector<2xf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}}) + %res1 = nvvm.cvt.f16x2.to.f8x2 %src : 
vector<2xf16> -> i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}}) + %res2 = nvvm.cvt.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 + llvm.return +} + +// CHECK-LABEL: @convert_f16x2_to_f8x2_e5m2 +llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}}) + %res1 = nvvm.cvt.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}}) + %res2 = nvvm.cvt.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 + llvm.return +} + +// CHECK-LABEL: @convert_f16x2_to_f8x2_vector_return +llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> + %res1 = nvvm.cvt.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> + %res2 = nvvm.cvt.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_bf16x2_to_f8x2_ue8m0 +llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) + %res1 = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}}) + %res2 = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}}) + %res3 = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) + %res4 = 
nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 + llvm.return +} + +// CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return +llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> + %res1 = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> + %res2 = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index f87f11daeef54..accec9c7af4f2 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -176,3 +176,75 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)> llvm.return } + +// ----- + +llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} + %res = nvvm.cvt.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) { + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} + %res = nvvm.cvt.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) { + // expected-error @below {{Only RZ 
or RP rounding modes are supported for conversions from f32x2 to .ue8m0x2 type}} + %res = nvvm.cvt.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) { + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} + %res = nvvm.cvt.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) { + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} + %res = nvvm.cvt.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) { + // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}} + %res = nvvm.cvt.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, relu = true} : i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) { + // expected-error @below {{Only .e4m3 or .e5m2 types are supported for conversions from f16x2 to f8x2.}} + %res = nvvm.cvt.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { + // expected-error @below {{Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.}} + %res = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) { + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} + %res = nvvm.cvt.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + llvm.return +}