Skip to content

Commit ec90912

Browse files
authored
[clang][NVPTX] Add remaining float to fp16 conversions (#167641)
This change adds intrinsics and clang builtins for the remaining float to fp16 conversions. This includes the following conversions:
- float to bf16x2 - satfinite variants
- float to f16x2 - satfinite variants
- float to bf16 - satfinite variants
- float to f16 - all variants

Tests are added in `convert-sm80.ll` and `convert-sm80-sf.ll` for the intrinsics and in `builtins-nvptx.c` for the clang builtins.
1 parent ac68dd5 commit ec90912

File tree

7 files changed

+451
-9
lines changed

7 files changed

+451
-9
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,10 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
579579
def __nvvm_ff2bf16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
580580
def __nvvm_ff2bf16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
581581
def __nvvm_ff2bf16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
582+
def __nvvm_ff2bf16x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
583+
def __nvvm_ff2bf16x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
584+
def __nvvm_ff2bf16x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
585+
def __nvvm_ff2bf16x2_rz_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
582586
def __nvvm_ff2bf16x2_rs :
583587
NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
584588
SM<"100a", [SM_103a]>, PTX87>;
@@ -596,6 +600,10 @@ def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)"
596600
def __nvvm_ff2f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
597601
def __nvvm_ff2f16x2_rz : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
598602
def __nvvm_ff2f16x2_rz_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
603+
def __nvvm_ff2f16x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
604+
def __nvvm_ff2f16x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
605+
def __nvvm_ff2f16x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
606+
def __nvvm_ff2f16x2_rz_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
599607
def __nvvm_ff2f16x2_rs :
600608
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
601609
SM<"100a", [SM_103a]>, PTX87>;
@@ -613,6 +621,19 @@ def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
613621
def __nvvm_f2bf16_rn_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
614622
def __nvvm_f2bf16_rz : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
615623
def __nvvm_f2bf16_rz_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
624+
def __nvvm_f2bf16_rn_satfinite : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
625+
def __nvvm_f2bf16_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
626+
def __nvvm_f2bf16_rz_satfinite : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
627+
def __nvvm_f2bf16_rz_relu_satfinite : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
628+
629+
def __nvvm_f2f16_rn : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
630+
def __nvvm_f2f16_rn_relu : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
631+
def __nvvm_f2f16_rz : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
632+
def __nvvm_f2f16_rz_relu : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
633+
def __nvvm_f2f16_rn_satfinite : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
634+
def __nvvm_f2f16_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
635+
def __nvvm_f2f16_rz_satfinite : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
636+
def __nvvm_f2f16_rz_relu_satfinite : NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
616637

617638
def __nvvm_f2tf32_rna : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_80, PTX70>;
618639
def __nvvm_f2tf32_rna_satfinite : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_80, PTX81>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,16 @@ __device__ void nvvm_cvt_sm80() {
10071007
__nvvm_ff2bf16x2_rz(1, 1);
10081008
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
10091009
__nvvm_ff2bf16x2_rz_relu(1, 1);
1010+
#if PTX >= 81
1011+
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
1012+
__nvvm_ff2bf16x2_rn_satfinite(1, 1);
1013+
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1014+
__nvvm_ff2bf16x2_rn_relu_satfinite(1, 1);
1015+
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float 1.000000e+00, float 1.000000e+00)
1016+
__nvvm_ff2bf16x2_rz_satfinite(1, 1);
1017+
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1018+
__nvvm_ff2bf16x2_rz_relu_satfinite(1, 1);
1019+
#endif
10101020

10111021
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
10121022
__nvvm_ff2f16x2_rn(1, 1);
@@ -1016,6 +1026,16 @@ __device__ void nvvm_cvt_sm80() {
10161026
__nvvm_ff2f16x2_rz(1, 1);
10171027
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
10181028
__nvvm_ff2f16x2_rz_relu(1, 1);
1029+
#if PTX >= 81
1030+
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
1031+
__nvvm_ff2f16x2_rn_satfinite(1, 1);
1032+
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1033+
__nvvm_ff2f16x2_rn_relu_satfinite(1, 1);
1034+
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float 1.000000e+00, float 1.000000e+00)
1035+
__nvvm_ff2f16x2_rz_satfinite(1, 1);
1036+
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1037+
__nvvm_ff2f16x2_rz_relu_satfinite(1, 1);
1038+
#endif
10191039

10201040
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
10211041
__nvvm_f2bf16_rn(1);
@@ -1025,6 +1045,35 @@ __device__ void nvvm_cvt_sm80() {
10251045
__nvvm_f2bf16_rz(1);
10261046
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
10271047
__nvvm_f2bf16_rz_relu(1);
1048+
#if PTX >= 81
1049+
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rn.satfinite(float 1.000000e+00)
1050+
__nvvm_f2bf16_rn_satfinite(1);
1051+
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu.satfinite(float 1.000000e+00)
1052+
__nvvm_f2bf16_rn_relu_satfinite(1);
1053+
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rz.satfinite(float 1.000000e+00)
1054+
__nvvm_f2bf16_rz_satfinite(1);
1055+
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu.satfinite(float 1.000000e+00)
1056+
__nvvm_f2bf16_rz_relu_satfinite(1);
1057+
#endif
1058+
1059+
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rn(float 1.000000e+00)
1060+
__nvvm_f2f16_rn(1);
1061+
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rn.relu(float 1.000000e+00)
1062+
__nvvm_f2f16_rn_relu(1);
1063+
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rz(float 1.000000e+00)
1064+
__nvvm_f2f16_rz(1);
1065+
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rz.relu(float 1.000000e+00)
1066+
__nvvm_f2f16_rz_relu(1);
1067+
#if PTX >= 81
1068+
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rn.satfinite(float 1.000000e+00)
1069+
__nvvm_f2f16_rn_satfinite(1);
1070+
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rn.relu.satfinite(float 1.000000e+00)
1071+
__nvvm_f2f16_rn_relu_satfinite(1);
1072+
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rz.satfinite(float 1.000000e+00)
1073+
__nvvm_f2f16_rz_satfinite(1);
1074+
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rz.relu.satfinite(float 1.000000e+00)
1075+
__nvvm_f2f16_rz_relu_satfinite(1);
1076+
#endif
10281077

10291078
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
10301079
__nvvm_f2tf32_rna(1);

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,14 +1566,19 @@ let TargetPrefix = "nvvm" in {
15661566

15671567
foreach rnd = ["rn", "rz"] in {
15681568
foreach relu = ["", "_relu"] in {
1569-
def int_nvvm_ff2bf16x2_ # rnd # relu : NVVMBuiltin,
1570-
PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty]>;
1571-
1572-
def int_nvvm_ff2f16x2_ # rnd # relu : NVVMBuiltin,
1573-
PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty]>;
1574-
1575-
def int_nvvm_f2bf16_ # rnd # relu : NVVMBuiltin,
1576-
PureIntrinsic<[llvm_bfloat_ty], [llvm_float_ty]>;
1569+
foreach satfinite = ["", "_satfinite"] in {
1570+
def int_nvvm_ff2bf16x2_ # rnd # relu # satfinite : NVVMBuiltin,
1571+
PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty]>;
1572+
1573+
def int_nvvm_ff2f16x2_ # rnd # relu # satfinite : NVVMBuiltin,
1574+
PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty]>;
1575+
1576+
def int_nvvm_f2bf16_ # rnd # relu # satfinite : NVVMBuiltin,
1577+
PureIntrinsic<[llvm_bfloat_ty], [llvm_float_ty]>;
1578+
1579+
def int_nvvm_f2f16_ # rnd # relu # satfinite : NVVMBuiltin,
1580+
PureIntrinsic<[llvm_half_ty], [llvm_float_ty]>;
1581+
}
15771582
}
15781583
}
15791584

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,15 @@ let hasSideEffects = false in {
595595
defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>;
596596
defm CVT_f32 : CVT_FROM_ALL<"f32", B32>;
597597
defm CVT_f64 : CVT_FROM_ALL<"f64", B64>;
598+
599+
multiclass CVT_FROM_FLOAT_SATFINITE<string ToName, RegisterClass RC> {
600+
def _f32_sf :
601+
BasicFlagsNVPTXInst<(outs RC:$dst),
602+
(ins B32:$src), (ins CvtMode:$mode),
603+
"cvt${mode:base}${mode:relu}.satfinite." # ToName # ".f32">;
604+
}
605+
defm CVT_bf16 : CVT_FROM_FLOAT_SATFINITE<"bf16", B16>;
606+
defm CVT_f16 : CVT_FROM_FLOAT_SATFINITE<"f16", B16>;
598607

599608
// These cvts are different from those above: The source and dest registers
600609
// are of the same type.
@@ -611,6 +620,11 @@ let hasSideEffects = false in {
611620
(ins B32:$src1, B32:$src2), (ins CvtMode:$mode),
612621
"cvt${mode:base}${mode:relu}." # FromName # ".f32">,
613622
Requires<[hasPTX<70>, hasSM<80>]>;
623+
624+
def _f32_sf :
625+
BasicFlagsNVPTXInst<(outs RC:$dst),
626+
(ins B32:$src1, B32:$src2), (ins CvtMode:$mode),
627+
"cvt${mode:base}${mode:relu}.satfinite." # FromName # ".f32">;
614628
}
615629

616630
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", B32>;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1952,7 +1952,12 @@ def : Pat<(int_nvvm_ff2bf16x2_rn f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, C
19521952
def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRN_RELU)>;
19531953
def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ)>;
19541954
def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>;
1955-
1955+
let Predicates = [hasPTX<81>, hasSM<80>] in {
1956+
def : Pat<(int_nvvm_ff2bf16x2_rn_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRN)>;
1957+
def : Pat<(int_nvvm_ff2bf16x2_rn_relu_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRN_RELU)>;
1958+
def : Pat<(int_nvvm_ff2bf16x2_rz_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRZ)>;
1959+
def : Pat<(int_nvvm_ff2bf16x2_rz_relu_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRZ_RELU)>;
1960+
}
19561961
let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
19571962
def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c),
19581963
(CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>;
@@ -1968,6 +1973,12 @@ def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, Cvt
19681973
def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>;
19691974
def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ)>;
19701975
def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>;
1976+
let Predicates = [hasPTX<81>, hasSM<80>] in {
1977+
def : Pat<(int_nvvm_ff2f16x2_rn_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRN)>;
1978+
def : Pat<(int_nvvm_ff2f16x2_rn_relu_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRN_RELU)>;
1979+
def : Pat<(int_nvvm_ff2f16x2_rz_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRZ)>;
1980+
def : Pat<(int_nvvm_ff2f16x2_rz_relu_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRZ_RELU)>;
1981+
}
19711982

19721983
let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
19731984
def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c),
@@ -1983,6 +1994,23 @@ def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>;
19831994
def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), (CVT_bf16_f32 $a, CvtRN_RELU)>;
19841995
def : Pat<(int_nvvm_f2bf16_rz f32:$a), (CVT_bf16_f32 $a, CvtRZ)>;
19851996
def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a), (CVT_bf16_f32 $a, CvtRZ_RELU)>;
1997+
let Predicates = [hasPTX<81>, hasSM<80>] in {
1998+
def : Pat<(int_nvvm_f2bf16_rz_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRZ)>;
1999+
def : Pat<(int_nvvm_f2bf16_rz_relu_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRZ_RELU)>;
2000+
def : Pat<(int_nvvm_f2bf16_rn_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRN)>;
2001+
def : Pat<(int_nvvm_f2bf16_rn_relu_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRN_RELU)>;
2002+
}
2003+
2004+
def : Pat<(int_nvvm_f2f16_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
2005+
def : Pat<(int_nvvm_f2f16_rn_relu f32:$a), (CVT_f16_f32 $a, CvtRN_RELU)>;
2006+
def : Pat<(int_nvvm_f2f16_rz f32:$a), (CVT_f16_f32 $a, CvtRZ)>;
2007+
def : Pat<(int_nvvm_f2f16_rz_relu f32:$a), (CVT_f16_f32 $a, CvtRZ_RELU)>;
2008+
let Predicates = [hasPTX<81>, hasSM<80>] in {
2009+
def : Pat<(int_nvvm_f2f16_rz_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRZ)>;
2010+
def : Pat<(int_nvvm_f2f16_rz_relu_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRZ_RELU)>;
2011+
def : Pat<(int_nvvm_f2f16_rn_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRN)>;
2012+
def : Pat<(int_nvvm_f2f16_rn_relu_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRN_RELU)>;
2013+
}
19862014

19872015
def : Pat<(int_nvvm_lohi_i2d i32:$a, i32:$b), (V2I32toI64 $a, $b)>;
19882016
def : Pat<(int_nvvm_d2i_lo f64:$a), (I64toI32L $a)>;

0 commit comments

Comments
 (0)