Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.td
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,19 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
// Stochastic-rounding (rs) f32 pair -> bf16x2 conversions. Unlike the rn/rz
// variants above, these take a third uint32_t argument (the random bits used
// for the conversion) and carry the SM<"100a", [SM_103a]> / PTX87 requirement.
def __nvvm_ff2bf16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2bf16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2bf16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2bf16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
// Stochastic-rounding (rs) f32 pair -> f16x2 conversions. Unlike the rn/rz
// variants above, these take a third uint32_t argument (the random bits used
// for the conversion) and carry the SM<"100a", [SM_103a]> / PTX87 requirement.
def __nvvm_ff2f16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
Expand Down Expand Up @@ -616,6 +624,11 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;

// Stochastic-rounding (rs) f32x4 -> fp8x4 (e4m3 / e5m2) conversions. Each
// builtin takes a 4 x float vector plus a uint32_t of random bits and returns
// a 4 x char vector. Only satfinite forms (with or without relu) are defined.
def __nvvm_f32x4_to_e4m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e5m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e3m2x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
Expand All @@ -626,12 +639,20 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

// Stochastic-rounding (rs) f32x4 -> fp6x4 (e2m3 / e3m2) conversions. Each
// builtin takes a 4 x float vector plus a uint32_t of random bits and returns
// a 4 x char vector. Only satfinite forms (with or without relu) are defined.
def __nvvm_f32x4_to_e2m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e3m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

// Stochastic-rounding (rs) f32x4 -> fp4x4 (e2m1) conversions. Each builtin
// takes a 4 x float vector plus a uint32_t of random bits. Unlike the fp8x4
// and fp6x4 forms, the result is returned as a single short, not a vector.
def __nvvm_f32x4_to_e2m1x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
Expand Down
83 changes: 83 additions & 0 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx86 -DPTX=86 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX86_SM120a %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_103a -target-feature +ptx87 -DPTX=87 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM103a %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 -DPTX=87 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM100a %s
// ### The last run to check with the highest SM and PTX version available
// ### to make sure target builtins are still accepted.
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx87 -DPTX=87 \
Expand Down Expand Up @@ -1203,6 +1209,83 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
// CHECK: ret void
}

// Exercises every stochastic-rounding (rs) conversion builtin with constant
// arguments. Each call is preceded by the IR expected under the sm_100a and
// sm_103a ptx87 configurations: a direct call to the like-named llvm.nvvm
// intrinsic, with the float/vector operands and the i32 random-bits operand
// passed through unchanged.
__device__ void nvvm_cvt_sm100a_sm103a() {
// The body is only compiled when PTX is at least 87 and either the SM100 or
// the SM103 __CUDA_ARCH_FEAT_*_ALL macro is set.
#if (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)

// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2f16x2_rs(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2f16x2_rs_relu(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2f16x2_rs_satfinite(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2f16x2_rs_relu_satfinite(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2bf16x2_rs(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2bf16x2_rs_relu(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2bf16x2_rs_satfinite(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
__nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e4m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e4m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e5m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e5m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e2m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e2m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e3m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e3m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e2m1x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
__nvvm_f32x4_to_e2m1x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
#endif
}

#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
Expand Down
37 changes: 37 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,18 @@ let TargetPrefix = "nvvm" in {
}
}

// RS (stochastic rounding) conversions from two f32 values to f16x2 / bf16x2.
// The last i32 operand provides the random bits for the conversion.
// The nested loops define four records per result type:
//   int_nvvm_ff2{f16x2,bf16x2}_rs{,_relu}{,_satfinite}
foreach relu = ["", "_relu"] in {
foreach satfinite = ["", "_satfinite"] in {
def int_nvvm_ff2f16x2_rs # relu # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;

def int_nvvm_ff2bf16x2_rs # relu # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
}
}

foreach satfinite = ["", "_satfinite"] in {
def int_nvvm_f2tf32_rna # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_i32_ty], [llvm_float_ty]>;
Expand All @@ -1515,6 +1527,15 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
}

// RS (stochastic rounding) conversions from f32x4 to the f8x4 types.
// The last i32 operand provides the random bits for the conversion.
// Defines int_nvvm_f32x4_to_{e4m3x4,e5m2x4}_rs{,_relu}_satfinite; the
// _satfinite suffix is unconditional, so no non-saturating form is created.
foreach type = ["e4m3x4", "e5m2x4"] in {
foreach relu = ["", "_relu"] in {
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
}
}

// FP4 conversions.
foreach relu = ["", "_relu"] in {
Expand All @@ -1524,6 +1545,13 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}

// RS (stochastic rounding) conversions from f32x4 to the f4x4 type.
// The last i32 operand provides the random bits for the conversion.
// Defines int_nvvm_f32x4_to_e2m1x4_rs{,_relu}_satfinite. The result is a
// single i16 rather than a vector type.
foreach relu = ["", "_relu"] in {
def int_nvvm_f32x4_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_i16_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
}

// FP6 conversions.
foreach type = ["e2m3x2", "e3m2x2"] in {
Expand All @@ -1535,6 +1563,15 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
}

// RS (stochastic rounding) conversions from f32x4 to the f6x4 types.
// The last i32 operand provides the random bits for the conversion.
// Defines int_nvvm_f32x4_to_{e2m3x4,e3m2x4}_rs{,_relu}_satfinite; the
// _satfinite suffix is unconditional, so no non-saturating form is created.
foreach type = ["e2m3x4", "e3m2x4"] in {
foreach relu = ["", "_relu"] in {
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
}
}

// UE8M0x2 conversions.
foreach rmode = ["_rz", "_rp"] in {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
return;
case NVPTX::PTXCvtMode::RS:
O << ".rs";
return;
}
}
llvm_unreachable("Invalid conversion modifier");
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ enum CvtMode {
RM,
RP,
RNA,
RS,

BASE_MASK = 0x0F,
FTZ_FLAG = 0x10,
Expand Down
77 changes: 75 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,9 +1077,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Enable custom lowering for the following:
// * MVT::i128 - clusterlaunchcontrol
// * MVT::i32 - prmt
// * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
// * MVT::Other - internal.addrspace.wrap
setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
Custom);
setOperationAction(ISD::INTRINSIC_WO_CHAIN,
{MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
Expand Down Expand Up @@ -1162,6 +1163,11 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
MAKE_CASE(
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
}
return nullptr;

Expand Down Expand Up @@ -2839,6 +2845,62 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
{TryCancelResponse0, TryCancelResponse1});
}

/// Lower the f32x4 -> fp8x4 / fp6x4 / fp4x4 stochastic-rounding (RS)
/// conversion intrinsics to their NVPTXISD::CVT_*_F32X4_RS_SF nodes.
///
/// Operand 0 of \p Op is the intrinsic ID, operand 1 the f32 source vector and
/// operand 2 the random bits. The emitted node takes six operands: the four
/// extracted f32 elements, the random bits, and an i32 constant holding the
/// conversion mode (RS, optionally OR'ed with RELU_FLAG).
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
  using NVPTX::PTXCvtMode::CvtMode;

  SDNode *Node = Op.getNode();
  SDLoc Loc(Node);
  const unsigned IID = Node->getConstantOperandVal(0);
  SDValue SrcVec = Node->getOperand(1);
  SDValue RandomBits = Node->getOperand(2);

  // Scalarize the source: one EXTRACT_VECTOR_ELT per element, in index order.
  SmallVector<SDValue, 6> NodeOps;
  for (unsigned Lane = 0; Lane != 4; ++Lane) {
    SDValue LaneIdx = DAG.getIntPtrConstant(Lane, Loc);
    NodeOps.push_back(
        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, Loc, MVT::f32, SrcVec, LaneIdx));
  }

  // The relu and non-relu intrinsics share a target node and differ only in
  // the mode constant.
  const bool HasRelu = [IID] {
    switch (IID) {
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
      return true;
    default:
      return false;
    }
  }();

  const NVPTXISD::NodeType Opcode = [IID]() -> NVPTXISD::NodeType {
    switch (IID) {
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
      return NVPTXISD::CVT_E4M3X4_F32X4_RS_SF;
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
      return NVPTXISD::CVT_E5M2X4_F32X4_RS_SF;
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
      return NVPTXISD::CVT_E2M3X4_F32X4_RS_SF;
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
      return NVPTXISD::CVT_E3M2X4_F32X4_RS_SF;
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
      return NVPTXISD::CVT_E2M1X4_F32X4_RS_SF;
    default:
      llvm_unreachable("unsupported/unhandled intrinsic");
    }
  }();

  // The e2m1x4 node produces an i16; every other node produces v4i8.
  const MVT::SimpleValueType ResultTy =
      Opcode == NVPTXISD::CVT_E2M1X4_F32X4_RS_SF ? MVT::i16 : MVT::v4i8;

  uint32_t ModeBits = CvtMode::RS;
  if (HasRelu)
    ModeBits |= CvtMode::RELU_FLAG;

  NodeOps.push_back(RandomBits);
  NodeOps.push_back(DAG.getConstant(ModeBits, Loc, MVT::i32));

  return DAG.getNode(Opcode, Loc, ResultTy, NodeOps);
}

static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
const unsigned Mode = [&]() {
switch (Op->getConstantOperandVal(0)) {
Expand Down Expand Up @@ -2886,6 +2948,17 @@ static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
return LowerClusterLaunchControlQueryCancel(Op, DAG);
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
return lowerCvtRSIntrinsics(Op, DAG);
}
}

Expand Down
Loading
Loading