diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 3e0a6987bd85b..a5d09eaa34eb5 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -147,9 +147,6 @@ LogicalResult CvtFloatToTF32Op::verify() { break; case RndMode::RN: case RndMode::RZ: - if (getSat() != NVVM::SaturationMode::NONE) - return emitError( - "Saturation mode not supported with rn/rz rounding modes."); break; default: return emitError( @@ -1221,21 +1218,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID( llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp"); } +#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \ + hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \ + : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf + +#define GET_CVT_F2TF32_ID(rnd, relu, sf) \ + hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \ + : CVT_F2TF32_ID_IMPL(rnd, relu, ) + llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu) { using RndMode = NVVM::FPRoundingMode; + bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); switch (rnd) { case RndMode::RN: - return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu - : llvm::Intrinsic::nvvm_f2tf32_rn; + return GET_CVT_F2TF32_ID(rn, _relu, _satfinite); case RndMode::RZ: - return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu - : llvm::Intrinsic::nvvm_f2tf32_rz; + return GET_CVT_F2TF32_ID(rz, _relu, _satfinite); case RndMode::RNA: - return (sat == NVVM::SaturationMode::SATFINITE) - ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite - : llvm::Intrinsic::nvvm_f2tf32_rna; + return GET_CVT_F2TF32_ID(rna, , _satfinite); default: llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op"); } diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir index 90a232e4baac6..ff7bad0149d4c 100644 --- a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir @@ -28,6 +28,20 @@ llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 { llvm.return %res : i32 } +// CHECK-LABEL: @convert_float_to_tf32_rn_sf +llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 { + // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.satfinite(float %{{.*}}) + %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} + llvm.return %res : i32 +} + +// CHECK-LABEL: @convert_float_to_tf32_rn_relu_sf +llvm.func @convert_float_to_tf32_rn_relu_sf(%src : f32) -> i32 { + // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu.satfinite(float %{{.*}}) + %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, relu=true, sat = #nvvm.sat_mode} + llvm.return %res : i32 +} + // CHECK-LABEL: @convert_float_to_tf32_rz llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 { // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}}) @@ -41,3 +55,17 @@ llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 { %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, relu=true} llvm.return %res : i32 } + +// CHECK-LABEL: @convert_float_to_tf32_rz_sf +llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 { + // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.satfinite(float %{{.*}}) + %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} + llvm.return %res : i32 +} + +// CHECK-LABEL: @convert_float_to_tf32_rz_relu_sf +llvm.func @convert_float_to_tf32_rz_relu_sf(%src : f32) -> i32 { + // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu.satfinite(float %{{.*}}) + %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, relu=true, sat = #nvvm.sat_mode} + llvm.return %res : i32 +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index cb08064590bc3..8957377607dad 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -117,22 +117,6 @@ llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 { // ----- -llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 { - // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}} - %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} - llvm.return %res : i32 -} - -// ----- - -llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 { - // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}} - %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} - llvm.return %res : i32 -} - -// ----- - llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 { // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}} %res = nvvm.cvt.float.to.tf32 %src