Skip to content

Commit 35a95fe

Browse files
authored
[clang][NVPTX] Fix SM requirement of f32-tf32 rna satfinite conversion (#167836)
This change fixes the SM requirement of the f32-to-tf32 conversion with the `rna` rounding mode and the `.satfinite` modifier. The requirement was previously specified as `sm_89`, but this conversion, which was introduced in PTX 8.1, is supported on `sm_80` and later. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent b3c5491 commit 35a95fe

File tree

5 files changed

+27
-12
lines changed

5 files changed

+27
-12
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def __nvvm_f2bf16_rz : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
615615
def __nvvm_f2bf16_rz_relu : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
616616

617617
def __nvvm_f2tf32_rna : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_80, PTX70>;
618-
def __nvvm_f2tf32_rna_satfinite : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_89, PTX81>;
618+
def __nvvm_f2tf32_rna_satfinite : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_80, PTX81>;
619619
def __nvvm_f2tf32_rn : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_90, PTX78>;
620620
def __nvvm_f2tf32_rn_relu : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_90, PTX78>;
621621
def __nvvm_f2tf32_rn_satfinite : NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_100, PTX86>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 -DPTX=81\
2929
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
3030
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s
31+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \
32+
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
33+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s
3134
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \
3235
// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \
3336
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s
@@ -1025,6 +1028,10 @@ __device__ void nvvm_cvt_sm80() {
10251028

10261029
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
10271030
__nvvm_f2tf32_rna(1);
1031+
#if PTX >= 81
1032+
// CHECK_PTX81_SM80: call i32 @llvm.nvvm.f2tf32.rna.satfinite(float 1.000000e+00)
1033+
__nvvm_f2tf32_rna_satfinite(1.0f);
1034+
#endif
10281035
#endif
10291036
// CHECK: ret void
10301037
}
@@ -1058,9 +1065,6 @@ __device__ void nvvm_cvt_sm89() {
10581065
__nvvm_e5m2x2_to_f16x2_rn(0x4c4c);
10591066
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 19532)
10601067
__nvvm_e5m2x2_to_f16x2_rn_relu(0x4c4c);
1061-
1062-
// CHECK_PTX81_SM89: call i32 @llvm.nvvm.f2tf32.rna.satfinite(float 1.000000e+00)
1063-
__nvvm_f2tf32_rna_satfinite(1.0f);
10641068
#endif
10651069
// CHECK: ret void
10661070
}

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ let hasSideEffects = false in {
683683
defm CVT_to_tf32_rn_relu : CVT_TO_TF32<"rn.relu">;
684684
defm CVT_to_tf32_rz_relu : CVT_TO_TF32<"rz.relu">;
685685
defm CVT_to_tf32_rna : CVT_TO_TF32<"rna", [hasPTX<70>, hasSM<80>]>;
686-
defm CVT_to_tf32_rna_satf : CVT_TO_TF32<"rna.satfinite", [hasPTX<81>, hasSM<89>]>;
686+
defm CVT_to_tf32_rna_satf : CVT_TO_TF32<"rna.satfinite", [hasPTX<81>, hasSM<80>]>;
687687

688688
defm CVT_to_tf32_rn_satf : CVT_TO_TF32<"rn.satfinite", [hasPTX<86>, hasSM<100>]>;
689689
defm CVT_to_tf32_rz_satf : CVT_TO_TF32<"rz.satfinite", [hasPTX<86>, hasSM<100>]>;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
3+
; RUN: %if ptxas-sm_80 && ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | %ptxas-verify -arch=sm_80 %}
4+
5+
; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
6+
define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
7+
; CHECK-LABEL: cvt_rna_satfinite_tf32_f32(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b32 %r<3>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: ld.param.b32 %r1, [cvt_rna_satfinite_tf32_f32_param_0];
13+
; CHECK-NEXT: cvt.rna.satfinite.tf32.f32 %r2, %r1;
14+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
15+
; CHECK-NEXT: ret;
16+
%val = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %f1)
17+
ret i32 %val
18+
}

llvm/test/CodeGen/NVPTX/convert-sm89.ll

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,3 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
8484
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
8585
ret <2 x half> %val
8686
}
87-
88-
; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
89-
define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
90-
; CHECK: cvt.rna.satfinite.tf32.f32
91-
%val = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %f1)
92-
ret i32 %val
93-
}

0 commit comments

Comments
 (0)