Skip to content

Commit 3f63c13

Browse files
Wolfram70 authored and aokblast committed
[clang][NVPTX] Add intrinsics and builtins for CVT RS rounding mode (llvm#160494)
This change adds LLVM intrinsics and clang builtins for the `cvt` RS rounding mode instruction variants. Tests are added in `convert-sm103a.ll` and verified through ptxas-13.0.
1 parent 264c38c commit 3f63c13

File tree

10 files changed: +715 -5 lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -579,11 +579,35 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
579579
def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
580580
def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
581581
def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
582+
def __nvvm_ff2bf16x2_rs :
583+
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
584+
SM<"100a", [SM_103a]>, PTX87>;
585+
def __nvvm_ff2bf16x2_rs_relu :
586+
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
587+
SM<"100a", [SM_103a]>, PTX87>;
588+
def __nvvm_ff2bf16x2_rs_satfinite :
589+
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
590+
SM<"100a", [SM_103a]>, PTX87>;
591+
def __nvvm_ff2bf16x2_rs_relu_satfinite :
592+
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
593+
SM<"100a", [SM_103a]>, PTX87>;
582594

583595
def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
584596
def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
585597
def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
586598
def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
599+
def __nvvm_ff2f16x2_rs :
600+
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
601+
SM<"100a", [SM_103a]>, PTX87>;
602+
def __nvvm_ff2f16x2_rs_relu :
603+
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
604+
SM<"100a", [SM_103a]>, PTX87>;
605+
def __nvvm_ff2f16x2_rs_satfinite :
606+
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
607+
SM<"100a", [SM_103a]>, PTX87>;
608+
def __nvvm_ff2f16x2_rs_relu_satfinite :
609+
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
610+
SM<"100a", [SM_103a]>, PTX87>;
587611

588612
def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
589613
def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
@@ -616,6 +640,19 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
616640
def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
617641
def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
618642

643+
def __nvvm_f32x4_to_e4m3x4_rs_satfinite :
644+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
645+
SM<"100a", [SM_103a]>, PTX87>;
646+
def __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite :
647+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
648+
SM<"100a", [SM_103a]>, PTX87>;
649+
def __nvvm_f32x4_to_e5m2x4_rs_satfinite :
650+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
651+
SM<"100a", [SM_103a]>, PTX87>;
652+
def __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite :
653+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
654+
SM<"100a", [SM_103a]>, PTX87>;
655+
619656
def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
620657
def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
621658
def __nvvm_ff_to_e3m2x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
@@ -626,12 +663,32 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
626663
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
627664
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
628665

666+
def __nvvm_f32x4_to_e2m3x4_rs_satfinite :
667+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
668+
SM<"100a", [SM_103a]>, PTX87>;
669+
def __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite :
670+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
671+
SM<"100a", [SM_103a]>, PTX87>;
672+
def __nvvm_f32x4_to_e3m2x4_rs_satfinite :
673+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
674+
SM<"100a", [SM_103a]>, PTX87>;
675+
def __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite :
676+
NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)",
677+
SM<"100a", [SM_103a]>, PTX87>;
678+
629679
def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
630680
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
631681

632682
def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
633683
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
634684

685+
def __nvvm_f32x4_to_e2m1x4_rs_satfinite :
686+
NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)",
687+
SM<"100a", [SM_103a]>, PTX87>;
688+
def __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite :
689+
NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)",
690+
SM<"100a", [SM_103a]>, PTX87>;
691+
635692
def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
636693
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
637694
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@
4343
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx86 -DPTX=86 \
4444
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
4545
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX86_SM120a %s
46+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_103a -target-feature +ptx87 -DPTX=87 \
47+
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
48+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM103a %s
49+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 -DPTX=87 \
50+
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
51+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM100a %s
4652
// ### The last run to check with the highest SM and PTX version available
4753
// ### to make sure target builtins are still accepted.
4854
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx87 -DPTX=87 \
@@ -1203,6 +1209,123 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
12031209
// CHECK: ret void
12041210
}
12051211

1212+
__device__ void nvvm_cvt_sm100a_sm103a() {
1213+
#if (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)
1214+
1215+
typedef __fp16 f16x2 __attribute__((ext_vector_type(2)));
1216+
typedef __bf16 bf16x2 __attribute__((ext_vector_type(2)));
1217+
typedef char uint8x4 __attribute__((ext_vector_type(4)));
1218+
1219+
// CHECK_PTX87_SM100a: %[[R1:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1220+
// CHECK_PTX87_SM100a: store <2 x half> %[[R1]], ptr %r1
1221+
// CHECK_PTX87_SM103a: %[[R1:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1222+
// CHECK_PTX87_SM103a: store <2 x half> %[[R1]], ptr %r1
1223+
f16x2 r1 = __nvvm_ff2f16x2_rs(1.0f, 1.0f, 0);
1224+
1225+
// CHECK_PTX87_SM100a: %[[R2:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1226+
// CHECK_PTX87_SM100a: store <2 x half> %[[R2]], ptr %r2
1227+
// CHECK_PTX87_SM103a: %[[R2:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1228+
// CHECK_PTX87_SM103a: store <2 x half> %[[R2]], ptr %r2
1229+
f16x2 r2 = __nvvm_ff2f16x2_rs_relu(1.0f, 1.0f, 0);
1230+
1231+
// CHECK_PTX87_SM100a: %[[R3:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1232+
// CHECK_PTX87_SM100a: store <2 x half> %[[R3]], ptr %r3
1233+
// CHECK_PTX87_SM103a: %[[R3:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1234+
// CHECK_PTX87_SM103a: store <2 x half> %[[R3]], ptr %r3
1235+
f16x2 r3 = __nvvm_ff2f16x2_rs_satfinite(1.0f, 1.0f, 0);
1236+
1237+
// CHECK_PTX87_SM100a: %[[R4:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1238+
// CHECK_PTX87_SM100a: store <2 x half> %[[R4]], ptr %r4
1239+
// CHECK_PTX87_SM103a: %[[R4:.*]] = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1240+
// CHECK_PTX87_SM103a: store <2 x half> %[[R4]], ptr %r4
1241+
f16x2 r4 = __nvvm_ff2f16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
1242+
1243+
// CHECK_PTX87_SM100a: %[[R5:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1244+
// CHECK_PTX87_SM100a: store <2 x bfloat> %[[R5]], ptr %r5
1245+
// CHECK_PTX87_SM103a: %[[R5:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1246+
// CHECK_PTX87_SM103a: store <2 x bfloat> %[[R5]], ptr %r5
1247+
bf16x2 r5 = __nvvm_ff2bf16x2_rs(1.0f, 1.0f, 0);
1248+
1249+
// CHECK_PTX87_SM100a: %[[R6:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1250+
// CHECK_PTX87_SM100a: store <2 x bfloat> %[[R6]], ptr %r6
1251+
// CHECK_PTX87_SM103a: %[[R6:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1252+
// CHECK_PTX87_SM103a: store <2 x bfloat> %[[R6]], ptr %r6
1253+
bf16x2 r6 = __nvvm_ff2bf16x2_rs_relu(1.0f, 1.0f, 0);
1254+
1255+
// CHECK_PTX87_SM100a: %[[R7:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1256+
// CHECK_PTX87_SM100a: store <2 x bfloat> %[[R7]], ptr %r7
1257+
// CHECK_PTX87_SM103a: %[[R7:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1258+
// CHECK_PTX87_SM103a: store <2 x bfloat> %[[R7]], ptr %r7
1259+
bf16x2 r7 = __nvvm_ff2bf16x2_rs_satfinite(1.0f, 1.0f, 0);
1260+
1261+
// CHECK_PTX87_SM100a: %[[R8:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1262+
// CHECK_PTX87_SM100a: store <2 x bfloat> %[[R8]], ptr %r8
1263+
// CHECK_PTX87_SM103a: %[[R8:.*]] = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1264+
// CHECK_PTX87_SM103a: store <2 x bfloat> %[[R8]], ptr %r8
1265+
bf16x2 r8 = __nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
1266+
1267+
// CHECK_PTX87_SM100a: %[[R9:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1268+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R9]], ptr %r9
1269+
// CHECK_PTX87_SM103a: %[[R9:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1270+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R9]], ptr %r9
1271+
uint8x4 r9 = __nvvm_f32x4_to_e4m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1272+
1273+
// CHECK_PTX87_SM100a: %[[R10:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1274+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R10]], ptr %r10
1275+
// CHECK_PTX87_SM103a: %[[R10:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1276+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R10]], ptr %r10
1277+
uint8x4 r10 = __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1278+
1279+
// CHECK_PTX87_SM100a: %[[R11:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1280+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R11]], ptr %r11
1281+
// CHECK_PTX87_SM103a: %[[R11:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1282+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R11]], ptr %r11
1283+
uint8x4 r11 = __nvvm_f32x4_to_e5m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1284+
1285+
// CHECK_PTX87_SM100a: %[[R12:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1286+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R12]], ptr %r12
1287+
// CHECK_PTX87_SM103a: %[[R12:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1288+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R12]], ptr %r12
1289+
uint8x4 r12 = __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1290+
1291+
// CHECK_PTX87_SM100a: %[[R13:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1292+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R13]], ptr %r13
1293+
// CHECK_PTX87_SM103a: %[[R13:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1294+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R13]], ptr %r13
1295+
uint8x4 r13 = __nvvm_f32x4_to_e2m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1296+
1297+
// CHECK_PTX87_SM100a: %[[R14:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1298+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R14]], ptr %r14
1299+
// CHECK_PTX87_SM103a: %[[R14:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1300+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R14]], ptr %r14
1301+
uint8x4 r14 = __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1302+
1303+
// CHECK_PTX87_SM100a: %[[R15:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1304+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R15]], ptr %r15
1305+
// CHECK_PTX87_SM103a: %[[R15:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1306+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R15]], ptr %r15
1307+
uint8x4 r15 = __nvvm_f32x4_to_e3m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1308+
1309+
// CHECK_PTX87_SM100a: %[[R16:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1310+
// CHECK_PTX87_SM100a: store <4 x i8> %[[R16]], ptr %r16
1311+
// CHECK_PTX87_SM103a: %[[R16:.*]] = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1312+
// CHECK_PTX87_SM103a: store <4 x i8> %[[R16]], ptr %r16
1313+
uint8x4 r16 = __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1314+
1315+
// CHECK_PTX87_SM100a: %[[R17:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1316+
// CHECK_PTX87_SM100a: store i16 %[[R17]], ptr %r17
1317+
// CHECK_PTX87_SM103a: %[[R17:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1318+
// CHECK_PTX87_SM103a: store i16 %[[R17]], ptr %r17
1319+
short r17 = __nvvm_f32x4_to_e2m1x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1320+
1321+
// CHECK_PTX87_SM100a: %[[R18:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1322+
// CHECK_PTX87_SM100a: store i16 %[[R18]], ptr %r18
1323+
// CHECK_PTX87_SM103a: %[[R18:.*]] = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1324+
// CHECK_PTX87_SM103a: store i16 %[[R18]], ptr %r18
1325+
short r18 = __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
1326+
#endif
1327+
}
1328+
12061329
#define NAN32 0x7FBFFFFF
12071330
#define NAN16 (__bf16)0x7FBF
12081331
#define BF16 (__bf16)0.1f

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,18 @@ let TargetPrefix = "nvvm" in {
14931493
}
14941494
}
14951495

1496+
// RS rounding mode (Stochastic Rounding) conversions for f16x2, bf16x2 types
1497+
// The last i32 operand provides the random bits for the conversion
1498+
foreach relu = ["", "_relu"] in {
1499+
foreach satfinite = ["", "_satfinite"] in {
1500+
def int_nvvm_ff2f16x2_rs # relu # satfinite : NVVMBuiltin,
1501+
PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1502+
1503+
def int_nvvm_ff2bf16x2_rs # relu # satfinite : NVVMBuiltin,
1504+
PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1505+
}
1506+
}
1507+
14961508
foreach satfinite = ["", "_satfinite"] in {
14971509
def int_nvvm_f2tf32_rna # satfinite : NVVMBuiltin,
14981510
PureIntrinsic<[llvm_i32_ty], [llvm_float_ty]>;
@@ -1515,6 +1527,15 @@ let TargetPrefix = "nvvm" in {
15151527
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
15161528
}
15171529
}
1530+
1531+
// RS rounding mode (Stochastic Rounding) conversions for f8x4 types
1532+
// The last i32 operand provides the random bits for the conversion
1533+
foreach type = ["e4m3x4", "e5m2x4"] in {
1534+
foreach relu = ["", "_relu"] in {
1535+
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1536+
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
1537+
}
1538+
}
15181539

15191540
// FP4 conversions.
15201541
foreach relu = ["", "_relu"] in {
@@ -1524,6 +1545,13 @@ let TargetPrefix = "nvvm" in {
15241545
def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
15251546
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
15261547
}
1548+
1549+
// RS rounding mode (Stochastic Rounding) conversions for f4x4 type
1550+
// The last i32 operand provides the random bits for the conversion
1551+
foreach relu = ["", "_relu"] in {
1552+
def int_nvvm_f32x4_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
1553+
PureIntrinsic<[llvm_i16_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
1554+
}
15271555

15281556
// FP6 conversions.
15291557
foreach type = ["e2m3x2", "e3m2x2"] in {
@@ -1535,6 +1563,15 @@ let TargetPrefix = "nvvm" in {
15351563
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
15361564
}
15371565
}
1566+
1567+
// RS rounding mode (Stochastic Rounding) conversions for f6x4 types
1568+
// The last i32 operand provides the random bits for the conversion
1569+
foreach type = ["e2m3x4", "e3m2x4"] in {
1570+
foreach relu = ["", "_relu"] in {
1571+
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1572+
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
1573+
}
1574+
}
15381575

15391576
// UE8M0x2 conversions.
15401577
foreach rmode = ["_rz", "_rp"] in {

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
149149
case NVPTX::PTXCvtMode::RNA:
150150
O << ".rna";
151151
return;
152+
case NVPTX::PTXCvtMode::RS:
153+
O << ".rs";
154+
return;
152155
}
153156
}
154157
llvm_unreachable("Invalid conversion modifier");

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ enum CvtMode {
207207
RM,
208208
RP,
209209
RNA,
210+
RS,
210211

211212
BASE_MASK = 0x0F,
212213
FTZ_FLAG = 0x10,

0 commit comments

Comments
 (0)