-
Couldn't load subscription status.
- Fork 15k
[clang][NVPTX] Add intrinsics and builtins for CVT RS rounding mode #160494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-clang Author: Srinivasa Ravi (Wolfram70) Changes: This change adds LLVM intrinsics and clang builtins for the `cvt` RS rounding mode instruction variants. Tests are added in `convert-sm103a.ll` and verified through ptxas-13.0. Patch is 44.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160494.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index 2d6fa1771014d..819262d87a917 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -579,11 +579,19 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
+def __nvvm_ff2bf16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2bf16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2bf16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2bf16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
+def __nvvm_ff2f16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2f16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2f16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_ff2f16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
@@ -616,6 +624,11 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
+def __nvvm_f32x4_to_e4m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e5m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+
def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e3m2x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
@@ -626,12 +639,20 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_f32x4_to_e2m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e3m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+
def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_f32x4_to_e2m1x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+def __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
+
def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index f994adb14e457..0cf116ea5c5b4 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -43,6 +43,12 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx86 -DPTX=86 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX86_SM120a %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_103a -target-feature +ptx87 -DPTX=87 \
+// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM103a %s
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 -DPTX=87 \
+// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
+// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM100a %s
// ### The last run to check with the highest SM and PTX version available
// ### to make sure target builtins are still accepted.
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx87 -DPTX=87 \
@@ -1203,6 +1209,83 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
// CHECK: ret void
}
+__device__ void nvvm_cvt_sm100a_sm103a() {
+#if (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)
+// Every RS builtin below must lower 1:1 to its llvm.nvvm.* intrinsic; the last argument is the i32 random-bits operand.
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2f16x2_rs(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2f16x2_rs_relu(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2f16x2_rs_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2f16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2bf16x2_rs(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2bf16x2_rs_relu(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2bf16x2_rs_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
+ __nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e4m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e5m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e2m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e3m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e2m1x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+
+// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
+ __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
+#endif // (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)
+}
+
#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7b40841e45d0d..78aedb99487cd 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1421,6 +1421,18 @@ let TargetPrefix = "nvvm" in {
}
}
+ // RS rounding mode (Stochastic Rounding) conversions for f16x2, bf16x2 types
+ // The last i32 operand provides the random bits for the conversion
+ foreach relu = ["", "_relu"] in {
+ foreach satfinite = ["", "_satfinite"] in {
+ def int_nvvm_ff2f16x2_rs # relu # satfinite : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
+
+ def int_nvvm_ff2bf16x2_rs # relu # satfinite : NVVMBuiltin,
+ PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
+ }
+ }
+
foreach satfinite = ["", "_satfinite"] in {
def int_nvvm_f2tf32_rna # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_i32_ty], [llvm_float_ty]>;
@@ -1443,6 +1455,15 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
}
+
+ // RS rounding mode (Stochastic Rounding) conversions for f8x4 types
+ // The last i32 operand provides the random bits for the conversion
+ foreach type = ["e4m3x4", "e5m2x4"] in {
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
+ PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
+ }
+ }
// FP4 conversions.
foreach relu = ["", "_relu"] in {
@@ -1452,6 +1473,13 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
+
+ // RS rounding mode (Stochastic Rounding) conversions for f4x4 type
+ // The last i32 operand provides the random bits for the conversion
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_f32x4_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
+ PureIntrinsic<[llvm_i16_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
+ }
// FP6 conversions.
foreach type = ["e2m3x2", "e3m2x2"] in {
@@ -1463,6 +1491,15 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
}
+
+ // RS rounding mode (Stochastic Rounding) conversions for f6x4 types
+ // The last i32 operand provides the random bits for the conversion
+ foreach type = ["e2m3x4", "e3m2x4"] in {
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
+ PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
+ }
+ }
// UE8M0x2 conversions.
foreach rmode = ["_rz", "_rp"] in {
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..77913f27838e2 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
return;
+ case NVPTX::PTXCvtMode::RS:
+ O << ".rs";
+ return;
}
}
llvm_unreachable("Invalid conversion modifier");
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 77a0e03d4075a..1e0f747f8f7fc 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -207,6 +207,7 @@ enum CvtMode {
RM,
RP,
RNA,
+ RS,
BASE_MASK = 0x0F,
FTZ_FLAG = 0x10,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index ca8a3f69f991d..05ada362ab946 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1077,9 +1077,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Enable custom lowering for the following:
// * MVT::i128 - clusterlaunchcontrol
// * MVT::i32 - prmt
+ // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
// * MVT::Other - internal.addrspace.wrap
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
- Custom);
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN,
+ {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
}
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1134,6 +1135,11 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X)
MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y)
MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z)
+ MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
}
return nullptr;
@@ -2693,6 +2699,69 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
{TryCancelResponse0, TryCancelResponse1});
}
+/// Returns true if \p ID is one of the `_relu` variants of the f32x4 ->
+/// fp{4,6,8}x4 stochastic-rounding (RS) conversion intrinsics, and false for
+/// every other intrinsic ID.
+///
+/// This helper is only used within this translation unit, so it is declared
+/// `static` to give it internal linkage.
+static bool isCvtRSReluIntrinsic(Intrinsic::ID ID) {
+  switch (ID) {
+  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+    return true;
+  default:
+    return false;
+  }
+}
+
+static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+ SDLoc DL(N);
+ SDValue F32Vec = N->getOperand(1);
+ SDValue RBits = N->getOperand(2);
+
+ unsigned IntrinsicID = N->getConstantOperandVal(0);
+
+ uint32_t CvtModeFlag = NVPTX::PTXCvtMode::CvtMode::RS;
+ if (isCvtRSReluIntrinsic(IntrinsicID))
+ CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
+
+ SDValue Float1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+ DAG.getIntPtrConstant(0, DL));
+ SDValue Float2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+ DAG.getIntPtrConstant(1, DL));
+ SDValue Float3 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+ DAG.getIntPtrConstant(2, DL));
+ SDValue Float4 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+ DAG.getIntPtrConstant(3, DL));
+
+ auto OpSignature =
+ [&]() -> std::pair<NVPTXISD::NodeType, MVT::SimpleValueType> {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
+ return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8};
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
+ return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8};
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
+ return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8};
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
+ return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8};
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfin...
[truncated]
|
c9950c7 to
23f472a
Compare
This change adds LLVM intrinsics and clang builtins for the `cvt` RS rounding mode instruction variants. Tests are added in `convert-sm103a.ll` and verified through ptxas-13.0.
23f472a to
5bd9195
Compare
|
LGTM, with one optional nit. @Artem-B, please help us with a review of the Clang side of things. |
c5ae79b to
d1f60de
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, modulo a few style and test nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/127/builds/4906 Here is the relevant piece of the build log for reference. |
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/3/builds/22957 Here is the relevant piece of the build log for reference. |
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/63/builds/11049 Here is the relevant piece of the build log for reference. |
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/186/builds/12952 Here is the relevant piece of the build log for reference. |
…lvm#160494) This change adds LLVM intrinsics and clang builtins for the `cvt` RS rounding mode instruction variants. Tests are added in `convert-sm103a.ll` and verified through ptxas-13.0.
This change adds LLVM intrinsics and clang builtins for the
`cvt` RS rounding mode instruction variants.
Tests are added in
`convert-sm103a.ll` and verified through ptxas-13.0.