diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8537c7030aa8f..c540c5ccf50bf 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Dialect/LLVMIR/LLVMTypes.td" +include "mlir/IR/CommonAttrConstraints.td" def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>; def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; @@ -1258,18 +1259,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> { }]; } -def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">; -def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">; - -def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind", - [ConvertFP6E2M3, ConvertFP6E3M2]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def ConvertFP6TypeAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ @@ -1290,19 +1279,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP6TypeAttr:$type, F32:$a, F32:$b, - DefaultValuedAttr:$relu); - let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; + let hasVerifier = 1; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type, + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu); + auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1312,19 +1302,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { }]; } -def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">; -def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">; -def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">; - -def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind", - [ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def ConvertFP8TypeAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> { let summary = "Convert a pair of float inputs to f8x2"; let description = [{ @@ -1346,23 +1323,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP8TypeAttr:$type, F32:$a, F32:$b, DefaultValuedAttr:$rnd, DefaultValuedAttr:$sat, - DefaultValuedAttr:$relu); - let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to, + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu); + auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1394,18 +1371,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP8TypeAttr:$type, VectorOfLengthAndType<[2], [F16]>:$a, - DefaultValuedAttr:$relu); - let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to, + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu); + auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1437,11 +1414,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP8TypeAttr:$type, VectorOfLengthAndType<[2], [BF16]>:$a, DefaultValuedAttr:$rnd, - DefaultValuedAttr:$sat); - let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; + DefaultValuedAttr:$sat, + TypeAttr:$dstTy); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd, diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 77ec1ebde3109..a04741e0b5ab2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -189,6 +189,18 @@ LogicalResult ConvertFloatToTF32Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF6x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getDstTy())) { + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f32x2 to f6x2."; + } + return success(); +} + LogicalResult ConvertF32x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; using SatMode = NVVM::SaturationMode; @@ -200,41 +212,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { bool hasRelu = getRelu(); - switch (getType()) { - case ConvertFP8Type::E4M3: - case ConvertFP8Type::E5M2: - if (!isRoundingModeRN) - return emitOpError("Only RN rounding mode is supported for conversions " - "from f32x2 to .e4m3x2 or .e5m2x2 types"); - if (!isSatFinite) - return emitOpError("Only SATFINITE saturation mode is supported for " - "conversions from f32x2 to .e4m3x2 or .e5m2x2 types"); - break; - case ConvertFP8Type::UE8M0: - if (!(isRoundingModeRZ || isRoundingModeRP)) - return emitOpError("Only RZ or RP rounding modes are supported for " - "conversions from f32x2 to .ue8m0x2 type"); - if (hasRelu) - return emitOpError("relu not supported for conversions to .ue8m0x2 type"); - break; - } - return success(); + mlir::MLIRContext *ctx = getContext(); + + return llvm::TypeSwitch(getDstTy()) + .Case( + [&](mlir::Type) -> LogicalResult { + if (!isRoundingModeRN) { + return emitOpError("Only RN rounding mode is supported for " + "conversions from f32x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + } + if (!isSatFinite) { + return emitOpError("Only SATFINITE saturation mode is supported " + "for conversions " + "from f32x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + } + return success(); + }) + .Case([&](mlir::Type) -> LogicalResult { + if (!(isRoundingModeRZ || isRoundingModeRP)) { + return emitOpError("Only RZ and RP rounding modes are supported for " + "conversions from f32x2 to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + } + if (hasRelu) { + return emitOpError("relu not supported for conversions to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + } + return success(); + }) + .Default([&](mlir::Type) { + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << ", " + << mlir::Float8E5M2Type::get(ctx) << ", and " + << mlir::Float8E8M0FNUType::get(ctx) + << " types are " + "supported for conversions from f32x2 to f8x2"; + }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { - if (getType() == ConvertFP8Type::UE8M0) - return emitOpError("Only .e4m3 or .e5m2 types are supported for " - "conversions from f16x2 to f8x2."); + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa(getDstTy())) { + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f16x2 to f8x2."; + } return success(); } LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; - if (getType() != ConvertFP8Type::UE8M0) - return emitOpError( - "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2."); + if (!llvm::isa(getDstTy())) + return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext()) + << " type is supported for conversions from " + "bf16x2 to f8x2."; auto rnd = getRnd(); if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) @@ -1714,15 +1752,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite -llvm::Intrinsic::ID -ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP6Type::E2M3: - return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); - case NVVM::ConvertFP6Type::E3M2: - return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); - } - llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); +llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float6E2M3FNType) { + return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); + }) + .Case([&](mlir::Float6E3M2FNType) { + return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \ @@ -1734,41 +1776,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { : llvm::Intrinsic::nvvm_ff_to_##type##_rn llvm::Intrinsic::ID -ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, - NVVM::FPRoundingMode rnd, +ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu) { bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ); bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); - case NVVM::ConvertFP8Type::UE8M0: - if (hasRoundingModeRZ) - return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); - else if (hasRoundingModeRP) - return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); - } - llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op"); + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float8E4M3FNType) { + return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); + }) + .Case([&](mlir::Float8E5M2Type) { + return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); + }) + .Case([&](mlir::Float8E8M0FNUType) { + if (hasRoundingModeRZ) + return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); + else if (hasRoundingModeRP) + return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); + + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F16x2_TO_F8X2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \ : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn -llvm::Intrinsic::ID -ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); - default: - llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op"); - } +llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float8E4M3FNType) { + return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); + }) + .Case([&](mlir::Float8E5M2Type) { + return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \ diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir index 04163b578aa02..99289923b58b1 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: @convert_f32x2_to_fp6x2_packed llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 + %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN) //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 + %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN) llvm.return } @@ -13,9 +13,9 @@ llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) { //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8> - %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> + %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN) //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> + %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN) llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir index 4a15efb9e805c..de21826445afb 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir @@ -5,31 +5,31 @@ // CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3 llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } // CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2 llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } // CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0 llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}}) - %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}}) - %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E8M0FNU) llvm.return } @@ -37,10 +37,10 @@ llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) { llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> (f8E4M3FN) // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> (f8E4M3FN) llvm.return } @@ -49,18 +49,18 @@ llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) { // CHECK-LABEL: @convert_f16x2_to_f8x2_e4m3 llvm.func @convert_f16x2_to_f8x2_e4m3(%src : vector<2xf16>) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}}) - %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E4M3FN) // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}}) - %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 + %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E4M3FN) llvm.return } // CHECK-LABEL: @convert_f16x2_to_f8x2_e5m2 llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}}) - %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E5M2) // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}}) - %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 + %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E5M2) llvm.return } @@ -68,10 +68,10 @@ llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) { llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> + %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN) // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> - %res2 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> + %res2 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E5M2) llvm.return } @@ -80,13 +80,13 @@ llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) { // CHECK-LABEL: @convert_bf16x2_to_f8x2_ue8m0 llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) - %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}}) - %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}}) - %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 + %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) - %res4 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 + %res4 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) llvm.return } @@ -94,9 +94,9 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) { llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> + %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> - %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> + %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index b35a6dbcca286..15ab66d6c511e 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -175,64 +175,64 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { - // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) { - // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) { - // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to .ue8m0x2 type}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : i16 + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from f32x2 to 'f8E8M0FNU' type}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) { - // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) { - // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) { - // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, relu = true} : i16 + // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, relu = true} : i16 (f8E8M0FNU) llvm.return } // ----- llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) { - // expected-error @below {{Only .e4m3 or .e5m2 types are supported for conversions from f16x2 to f8x2.}} - %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f16x2 to f8x2.}} + %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E8M0FNU) llvm.return } // ----- llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { - // expected-error @below {{Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.}} - %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E4M3FN) llvm.return } @@ -240,7 +240,15 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) { // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} - %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) { + // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}} + %res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU) llvm.return }