Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.td
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,35 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
// Stochastic-rounding (rs) float,float -> bf16x2 conversion builtins, in all
// four combinations of the optional relu and satfinite modifiers. Unlike the
// rn/rz variants above, each takes a third uint32_t operand (the random bits
// consumed by the rounding). All are declared with SM<"100a", [SM_103a]> and
// PTX87.
def __nvvm_ff2bf16x2_rs :
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2bf16x2_rs_relu :
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2bf16x2_rs_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2bf16x2_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
// Stochastic-rounding (rs) float,float -> f16x2 conversion builtins, in all
// four combinations of the optional relu and satfinite modifiers. Unlike the
// rn/rz variants above, each takes a third uint32_t operand (the random bits
// consumed by the rounding). All are declared with SM<"100a", [SM_103a]> and
// PTX87.
def __nvvm_ff2f16x2_rs :
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rs_relu :
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rs_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_ff2f16x2_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
Expand Down Expand Up @@ -616,6 +640,19 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;

// Stochastic-rounding (rs) conversions from a 4 x float vector to the
// e4m3x4 / e5m2x4 formats. The result is a 4 x char vector and the trailing
// uint32_t operand supplies the random bits. Only satfinite forms exist
// (with and without relu). All are declared with SM<"100a", [SM_103a]> and
// PTX87.
def __nvvm_f32x4_to_e4m3x4_rs_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e5m2x4_rs_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e3m2x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
Expand All @@ -626,12 +663,32 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

// Stochastic-rounding (rs) conversions from a 4 x float vector to the
// e2m3x4 / e3m2x4 formats. The result is a 4 x char vector and the trailing
// uint32_t operand supplies the random bits. Only satfinite forms exist
// (with and without relu). All are declared with SM<"100a", [SM_103a]> and
// PTX87.
def __nvvm_f32x4_to_e2m3x4_rs_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e3m2x4_rs_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

// Stochastic-rounding (rs) conversions from a 4 x float vector to e2m1x4.
// Unlike the other f32x4 rs builtins, the result is a scalar short rather
// than a 4 x char vector (presumably the four 4-bit values packed together
// -- confirm against the PTX ISA). The trailing uint32_t operand supplies
// the random bits. Declared with SM<"100a", [SM_103a]> and PTX87.
def __nvvm_f32x4_to_e2m1x4_rs_satfinite :
NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;
def __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite :
NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)",
SM<"100a", [SM_103a]>, PTX87>;

def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
Expand Down
123 changes: 123 additions & 0 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx86 -DPTX=86 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX86_SM120a %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_103a -target-feature +ptx87 -DPTX=87 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM103a %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 -DPTX=87 \
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM100a %s
// ### The last run to check with the highest SM and PTX version available
// ### to make sure target builtins are still accepted.
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx87 -DPTX=87 \
Expand Down Expand Up @@ -1203,6 +1209,123 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
// CHECK: ret void
}

// Exercises the stochastic-rounding (rs) conversion builtins. The body is
// only compiled when PTX >= 87 and the target is sm_100a or sm_103a. Each
// builtin is called with constant operands (1.0f inputs, 0 for the trailing
// random-bits argument), and both RUN configurations verify the emitted
// intrinsic call as well as the store of its result into the named local.
__device__ void nvvm_cvt_sm100a_sm103a() {
#if (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)

  typedef __fp16 f16x2 __attribute__((ext_vector_type(2)));
  typedef __bf16 bf16x2 __attribute__((ext_vector_type(2)));
  typedef char uint8x4 __attribute__((ext_vector_type(4)));

  // CHECK_PTX87_SM100a: %[[R1:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x half> %[[R1]], ptr %r1
  // CHECK_PTX87_SM103a: %[[R1:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x half> %[[R1]], ptr %r1
  f16x2 r1 = __nvvm_ff2f16x2_rs(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R2:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x half> %[[R2]], ptr %r2
  // CHECK_PTX87_SM103a: %[[R2:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x half> %[[R2]], ptr %r2
  f16x2 r2 = __nvvm_ff2f16x2_rs_relu(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R3:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x half> %[[R3]], ptr %r3
  // CHECK_PTX87_SM103a: %[[R3:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x half> %[[R3]], ptr %r3
  f16x2 r3 = __nvvm_ff2f16x2_rs_satfinite(1.0f, 1.0f, 0);

  // The sm_103a store check below names its destination (%r4) like every
  // other store check in this function, so it cannot match an unrelated store.
  // CHECK_PTX87_SM100a: %[[R4:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x half> %[[R4]], ptr %r4
  // CHECK_PTX87_SM103a: %[[R4:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x half> %[[R4]], ptr %r4
  f16x2 r4 = __nvvm_ff2f16x2_rs_relu_satfinite(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R5:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x bfloat> %[[R5]], ptr %r5
  // CHECK_PTX87_SM103a: %[[R5:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x bfloat> %[[R5]], ptr %r5
  bf16x2 r5 = __nvvm_ff2bf16x2_rs(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R6:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x bfloat> %[[R6]], ptr %r6
  // CHECK_PTX87_SM103a: %[[R6:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x bfloat> %[[R6]], ptr %r6
  bf16x2 r6 = __nvvm_ff2bf16x2_rs_relu(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R7:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x bfloat> %[[R7]], ptr %r7
  // CHECK_PTX87_SM103a: %[[R7:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x bfloat> %[[R7]], ptr %r7
  bf16x2 r7 = __nvvm_ff2bf16x2_rs_satfinite(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R8:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM100a: store <2 x bfloat> %[[R8]], ptr %r8
  // CHECK_PTX87_SM103a: %[[R8:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
  // CHECK_PTX87_SM103a: store <2 x bfloat> %[[R8]], ptr %r8
  bf16x2 r8 = __nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);

  // CHECK_PTX87_SM100a: %[[R9:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R9]], ptr %r9
  // CHECK_PTX87_SM103a: %[[R9:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R9]], ptr %r9
  uint8x4 r9 = __nvvm_f32x4_to_e4m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R10:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R10]], ptr %r10
  // CHECK_PTX87_SM103a: %[[R10:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R10]], ptr %r10
  uint8x4 r10 = __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R11:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R11]], ptr %r11
  // CHECK_PTX87_SM103a: %[[R11:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R11]], ptr %r11
  uint8x4 r11 = __nvvm_f32x4_to_e5m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R12:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R12]], ptr %r12
  // CHECK_PTX87_SM103a: %[[R12:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R12]], ptr %r12
  uint8x4 r12 = __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R13:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R13]], ptr %r13
  // CHECK_PTX87_SM103a: %[[R13:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R13]], ptr %r13
  uint8x4 r13 = __nvvm_f32x4_to_e2m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R14:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R14]], ptr %r14
  // CHECK_PTX87_SM103a: %[[R14:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R14]], ptr %r14
  uint8x4 r14 = __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R15:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R15]], ptr %r15
  // CHECK_PTX87_SM103a: %[[R15:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R15]], ptr %r15
  uint8x4 r15 = __nvvm_f32x4_to_e3m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R16:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store <4 x i8> %[[R16]], ptr %r16
  // CHECK_PTX87_SM103a: %[[R16:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store <4 x i8> %[[R16]], ptr %r16
  uint8x4 r16 = __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R17:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store i16 %[[R17]], ptr %r17
  // CHECK_PTX87_SM103a: %[[R17:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store i16 %[[R17]], ptr %r17
  short r17 = __nvvm_f32x4_to_e2m1x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);

  // CHECK_PTX87_SM100a: %[[R18:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM100a: store i16 %[[R18]], ptr %r18
  // CHECK_PTX87_SM103a: %[[R18:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
  // CHECK_PTX87_SM103a: store i16 %[[R18]], ptr %r18
  short r18 = __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
#endif
}

#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
Expand Down
37 changes: 37 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,18 @@ let TargetPrefix = "nvvm" in {
}
}

// RS rounding mode (Stochastic Rounding) conversions for f16x2, bf16x2 types.
// The last i32 operand provides the random bits for the conversion.
// The nested loops expand each def into four intrinsics, with the suffixes
// _rs, _rs_satfinite, _rs_relu and _rs_relu_satfinite. Each takes two floats
// plus the i32 and returns a v2f16 (ff2f16x2) or v2bf16 (ff2bf16x2).
foreach relu = ["", "_relu"] in {
foreach satfinite = ["", "_satfinite"] in {
def int_nvvm_ff2f16x2_rs # relu # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;

def int_nvvm_ff2bf16x2_rs # relu # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
}
}

foreach satfinite = ["", "_satfinite"] in {
def int_nvvm_f2tf32_rna # satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_i32_ty], [llvm_float_ty]>;
Expand All @@ -1515,6 +1527,15 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
}

// RS rounding mode (Stochastic Rounding) conversions for f8x4 types.
// The last i32 operand provides the random bits for the conversion.
// Expands to int_nvvm_f32x4_to_{e4m3x4,e5m2x4}_rs{,_relu}_satfinite; the
// _satfinite suffix is unconditional, so no non-saturating form is defined.
// Each takes a v4f32 plus the i32 and returns a v4i8.
foreach type = ["e4m3x4", "e5m2x4"] in {
foreach relu = ["", "_relu"] in {
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
}
}

// FP4 conversions.
foreach relu = ["", "_relu"] in {
Expand All @@ -1524,6 +1545,13 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}

// RS rounding mode (Stochastic Rounding) conversions for f4x4 type.
// The last i32 operand provides the random bits for the conversion.
// Expands to int_nvvm_f32x4_to_e2m1x4_rs{,_relu}_satfinite; the _satfinite
// suffix is unconditional. Each takes a v4f32 plus the i32 and returns a
// scalar i16 rather than a vector.
foreach relu = ["", "_relu"] in {
def int_nvvm_f32x4_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_i16_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
}

// FP6 conversions.
foreach type = ["e2m3x2", "e3m2x2"] in {
Expand All @@ -1535,6 +1563,15 @@ let TargetPrefix = "nvvm" in {
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
}
}

// RS rounding mode (Stochastic Rounding) conversions for f6x4 types.
// The last i32 operand provides the random bits for the conversion.
// Expands to int_nvvm_f32x4_to_{e2m3x4,e3m2x4}_rs{,_relu}_satfinite; the
// _satfinite suffix is unconditional, so no non-saturating form is defined.
// Each takes a v4f32 plus the i32 and returns a v4i8.
foreach type = ["e2m3x4", "e3m2x4"] in {
foreach relu = ["", "_relu"] in {
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
}
}

// UE8M0x2 conversions.
foreach rmode = ["_rz", "_rp"] in {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
return;
case NVPTX::PTXCvtMode::RS:
O << ".rs";
return;
}
}
llvm_unreachable("Invalid conversion modifier");
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ enum CvtMode {
RM,
RP,
RNA,
RS,

BASE_MASK = 0x0F,
FTZ_FLAG = 0x10,
Expand Down
Loading