Skip to content

Commit 398398b

Browse files
committed
[clang][NVPTX] Add intrinsics and builtins for cvt RS rounding mode
This change adds LLVM intrinsics and clang builtins for the `cvt` RS rounding mode instruction variants. Tests are added in `convert-sm103a.ll` and verified through ptxas-13.0.
1 parent 70a26da commit 398398b

File tree

8 files changed

+572
-0
lines changed

8 files changed

+572
-0
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,19 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
579579
def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
580580
def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
581581
def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
582+
def __nvvm_ff2bf16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
583+
def __nvvm_ff2bf16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
584+
def __nvvm_ff2bf16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
585+
def __nvvm_ff2bf16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
582586

583587
def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
584588
def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
585589
def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
586590
def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
591+
def __nvvm_ff2f16x2_rs : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
592+
def __nvvm_ff2f16x2_rs_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
593+
def __nvvm_ff2f16x2_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
594+
def __nvvm_ff2f16x2_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
587595

588596
def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
589597
def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
@@ -616,6 +624,11 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
616624
def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
617625
def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
618626

627+
def __nvvm_ff_to_e4m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
628+
def __nvvm_ff_to_e4m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
629+
def __nvvm_ff_to_e5m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
630+
def __nvvm_ff_to_e5m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
631+
619632
def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
620633
def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
621634
def __nvvm_ff_to_e3m2x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
@@ -626,12 +639,20 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
626639
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
627640
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
628641

642+
def __nvvm_ff_to_e2m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
643+
def __nvvm_ff_to_e2m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
644+
def __nvvm_ff_to_e3m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
645+
def __nvvm_ff_to_e3m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
646+
629647
def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
630648
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
631649

632650
def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
633651
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
634652

653+
def __nvvm_ff_to_e2m1x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
654+
def __nvvm_ff_to_e2m1x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
655+
635656
def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
636657
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
637658
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@
4343
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx86 -DPTX=86 \
4444
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
4545
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX86_SM120a %s
46+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_103a -target-feature +ptx87 -DPTX=87 \
47+
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
48+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM103a %s
49+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_100a -target-feature +ptx87 -DPTX=87 \
50+
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
51+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX87_SM100a %s
4652
// ### The last run to check with the highest SM and PTX version available
4753
// ### to make sure target builtins are still accepted.
4854
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_120a -target-feature +ptx87 -DPTX=87 \
@@ -1203,6 +1209,83 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
12031209
// CHECK: ret void
12041210
}
12051211

1212+
__device__ void nvvm_cvt_sm100a_sm103a() {
1213+
#if (PTX >= 87) && (__CUDA_ARCH_FEAT_SM100_ALL || __CUDA_ARCH_FEAT_SM103_ALL)
1214+
1215+
// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1216+
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1217+
__nvvm_ff2f16x2_rs(1.0f, 1.0f, 0);
1218+
1219+
// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1220+
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1221+
__nvvm_ff2f16x2_rs_relu(1.0f, 1.0f, 0);
1222+
1223+
// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1224+
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1225+
__nvvm_ff2f16x2_rs_satfinite(1.0f, 1.0f, 0);
1226+
1227+
// CHECK_PTX87_SM100a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1228+
// CHECK_PTX87_SM103a: call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1229+
__nvvm_ff2f16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
1230+
1231+
// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1232+
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float 1.000000e+00, float 1.000000e+00, i32 0)
1233+
__nvvm_ff2bf16x2_rs(1.0f, 1.0f, 0);
1234+
1235+
// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1236+
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float 1.000000e+00, float 1.000000e+00, i32 0)
1237+
__nvvm_ff2bf16x2_rs_relu(1.0f, 1.0f, 0);
1238+
1239+
// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1240+
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1241+
__nvvm_ff2bf16x2_rs_satfinite(1.0f, 1.0f, 0);
1242+
1243+
// CHECK_PTX87_SM100a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1244+
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
1245+
__nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
1246+
1247+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1248+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1249+
__nvvm_ff_to_e4m3x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1250+
1251+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1252+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1253+
__nvvm_ff_to_e4m3x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1254+
1255+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1256+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1257+
__nvvm_ff_to_e5m2x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1258+
1259+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1260+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1261+
__nvvm_ff_to_e5m2x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1262+
1263+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1264+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1265+
__nvvm_ff_to_e2m3x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1266+
1267+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1268+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1269+
__nvvm_ff_to_e2m3x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1270+
1271+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1272+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1273+
__nvvm_ff_to_e3m2x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1274+
1275+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1276+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1277+
__nvvm_ff_to_e3m2x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1278+
1279+
// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1280+
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1281+
__nvvm_ff_to_e2m1x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1282+
1283+
// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1284+
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1285+
__nvvm_ff_to_e2m1x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1286+
#endif
1287+
}
1288+
12061289
#define NAN32 0x7FBFFFFF
12071290
#define NAN16 (__bf16)0x7FBF
12081291
#define BF16 (__bf16)0.1f

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,17 @@ let TargetPrefix = "nvvm" in {
14931493
}
14941494
}
14951495

1496+
// RS rounding mode conversions for f16x2, bf16x2 types
1497+
foreach relu = ["", "_relu"] in {
1498+
foreach satfinite = ["", "_satfinite"] in {
1499+
def int_nvvm_ff2f16x2_rs # relu # satfinite : NVVMBuiltin,
1500+
PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1501+
1502+
def int_nvvm_ff2bf16x2_rs # relu # satfinite : NVVMBuiltin,
1503+
PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1504+
}
1505+
}
1506+
14961507
foreach satfinite = ["", "_satfinite"] in {
14971508
def int_nvvm_f2tf32_rna # satfinite : NVVMBuiltin,
14981509
PureIntrinsic<[llvm_i32_ty], [llvm_float_ty]>;
@@ -1515,6 +1526,14 @@ let TargetPrefix = "nvvm" in {
15151526
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
15161527
}
15171528
}
1529+
1530+
// RS rounding mode conversions for f8x4 types
1531+
foreach type = ["e4m3x4", "e5m2x4"] in {
1532+
foreach relu = ["", "_relu"] in {
1533+
def int_nvvm_ff_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1534+
PureIntrinsic<[llvm_v4i8_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1535+
}
1536+
}
15181537

15191538
// FP4 conversions.
15201539
foreach relu = ["", "_relu"] in {
@@ -1524,6 +1543,12 @@ let TargetPrefix = "nvvm" in {
15241543
def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
15251544
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
15261545
}
1546+
1547+
// RS rounding mode conversions for f4x4 type
1548+
foreach relu = ["", "_relu"] in {
1549+
def int_nvvm_ff_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
1550+
PureIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1551+
}
15271552

15281553
// FP6 conversions.
15291554
foreach type = ["e2m3x2", "e3m2x2"] in {
@@ -1535,6 +1560,14 @@ let TargetPrefix = "nvvm" in {
15351560
PureIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty]>;
15361561
}
15371562
}
1563+
1564+
// RS rounding mode conversions for f6x4 types
1565+
foreach type = ["e2m3x4", "e3m2x4"] in {
1566+
foreach relu = ["", "_relu"] in {
1567+
def int_nvvm_ff_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1568+
PureIntrinsic<[llvm_v4i8_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1569+
}
1570+
}
15381571

15391572
// UE8M0x2 conversions.
15401573
foreach rmode = ["_rz", "_rp"] in {

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
149149
case NVPTX::PTXCvtMode::RNA:
150150
O << ".rna";
151151
return;
152+
case NVPTX::PTXCvtMode::RS:
153+
O << ".rs";
154+
return;
152155
}
153156
}
154157
llvm_unreachable("Invalid conversion modifier");

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ enum CvtMode {
207207
RM,
208208
RP,
209209
RNA,
210+
RS,
210211

211212
BASE_MASK = 0x0F,
212213
FTZ_FLAG = 0x10,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def CvtRZ : PatLeaf<(i32 0x6)>;
3535
def CvtRM : PatLeaf<(i32 0x7)>;
3636
def CvtRP : PatLeaf<(i32 0x8)>;
3737
def CvtRNA : PatLeaf<(i32 0x9)>;
38+
def CvtRS : PatLeaf<(i32 0xA)>;
3839

3940
def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
4041
def CvtRNI_FTZ : PatLeaf<(i32 0x11)>;
@@ -52,6 +53,7 @@ def CvtSAT_FTZ : PatLeaf<(i32 0x30)>;
5253
def CvtNONE_RELU : PatLeaf<(i32 0x40)>;
5354
def CvtRN_RELU : PatLeaf<(i32 0x45)>;
5455
def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
56+
def CvtRS_RELU : PatLeaf<(i32 0x4A)>;
5557

5658
def CvtMode : Operand<i32> {
5759
let PrintMethod = "printCvtMode";
@@ -132,6 +134,9 @@ def hasSM100a : Predicate<"Subtarget->getSmVersion() == 100 && Subtarget->hasArc
132134
def hasSM101a : Predicate<"Subtarget->getSmVersion() == 101 && Subtarget->hasArchAccelFeatures()">;
133135
def hasSM120a : Predicate<"Subtarget->getSmVersion() == 120 && Subtarget->hasArchAccelFeatures()">;
134136

137+
def hasSM100aOrSM103a :
138+
Predicate<"(Subtarget->getSmVersion() == 100 || Subtarget->getSmVersion() == 103) && Subtarget->hasArchAccelFeatures()">;
139+
135140
// non-sync shfl instructions are not available on sm_70+ in PTX6.4+
136141
def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
137142
"&& Subtarget->getPTXVersion() >= 64)">;
@@ -592,6 +597,21 @@ let hasSideEffects = false in {
592597

593598
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", B32>;
594599
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", B32>;
600+
601+
multiclass CVT_FROM_FLOAT_V2_RS<string FromName, RegisterClass RC> {
602+
def _f32_rs :
603+
BasicFlagsNVPTXInst<(outs RC:$dst),
604+
(ins B32:$src1, B32:$src2, B32:$src3), (ins CvtMode:$mode),
605+
"cvt${mode:base}${mode:relu}." # FromName # ".f32">;
606+
607+
def _f32_rs_sf :
608+
BasicFlagsNVPTXInst<(outs RC:$dst),
609+
(ins B32:$src1, B32:$src2, B32:$src3), (ins CvtMode:$mode),
610+
"cvt${mode:base}${mode:relu}.satfinite." # FromName # ".f32">;
611+
}
612+
613+
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_RS<"f16x2", B32>;
614+
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_RS<"bf16x2", B32>;
595615

596616
// FP8 conversions.
597617
multiclass CVT_TO_F8X2<string F8Name> {
@@ -618,6 +638,15 @@ let hasSideEffects = false in {
618638

619639
def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
620640
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
641+
642+
class CVT_TO_FP8X4<string F8Name> :
643+
NVPTXInst<(outs B32:$dst),
644+
(ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5,
645+
CvtMode:$mode),
646+
"cvt${mode:base}${mode:relu}.satfinite." # F8Name # "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
647+
648+
def CVT_e4m3x4_f32_rs_sf : CVT_TO_FP8X4<"e4m3">;
649+
def CVT_e5m2x4_f32_rs_sf : CVT_TO_FP8X4<"e5m2">;
621650

622651
// Float to TF32 conversions
623652
multiclass CVT_TO_TF32<string Modifier, list<Predicate> Preds = [hasPTX<78>, hasSM<90>]> {
@@ -651,6 +680,15 @@ let hasSideEffects = false in {
651680
"cvt${mode:base}${mode:relu}.f16x2." # type>;
652681
}
653682

683+
class CVT_TO_FP6X4<string F6Name> :
684+
NVPTXInst<(outs B32:$dst),
685+
(ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5,
686+
CvtMode:$mode),
687+
"cvt${mode:base}${mode:relu}.satfinite." # F6Name # "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
688+
689+
def CVT_e2m3x4_f32_rs_sf : CVT_TO_FP6X4<"e2m3">;
690+
def CVT_e3m2x4_f32_rs_sf : CVT_TO_FP6X4<"e3m2">;
691+
654692
// FP4 conversions.
655693
def CVT_e2m1x2_f32_sf : NVPTXInst<(outs B16:$dst),
656694
(ins B32:$src1, B32:$src2, CvtMode:$mode),
@@ -667,6 +705,12 @@ let hasSideEffects = false in {
667705
"cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
668706
"cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t",
669707
"}}"), []>;
708+
709+
def CVT_e2m1x4_f32_rs_sf :
710+
NVPTXInst<(outs B16:$dst),
711+
(ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5,
712+
CvtMode:$mode),
713+
"cvt${mode:base}${mode:relu}.satfinite.e2m1x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
670714

671715
// UE8M0x2 conversions.
672716
class CVT_f32_to_ue8m0x2<string sat = ""> :

0 commit comments

Comments
 (0)