Skip to content

Commit 281a6c1

Browse files
committed
[NVVMReflect] Recognize __CUDA_PREC_SQRT
The `__nv_sqrtf` intrinsic in libdevice.bc, defined by NVIDIA, depends not only on `__nvvm_reflect("__CUDA_FTZ")` but also on `__nvvm_reflect("__CUDA_PREC_SQRT")`. However, the NVVMReflect pass previously failed to recognize `__CUDA_PREC_SQRT`, causing its value to default to `0`. This change enables the NVVMReflect pass to correctly pick up the module flag "nvvm-reflect-prec-sqrt", which Clang sets based on the `-fcuda-prec-sqrt` flag, ensuring proper behavior.
1 parent 203f061 commit 281a6c1

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
173173
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
174174
F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
175175
ReflectVal = Flag->getSExtValue();
176+
} else if (ReflectArg == "__CUDA_PREC_SQRT") {
177+
// Try to pull __CUDA_PREC_SQRT from the nvvm-reflect-prec-sqrt module
178+
// flag.
179+
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
180+
F.getParent()->getModuleFlag("nvvm-reflect-prec-sqrt")))
181+
ReflectVal = Flag->getSExtValue();
176182
} else if (ReflectArg == "__CUDA_ARCH") {
177183
ReflectVal = SmVersion * 10;
178184
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; We run nvvm-reflect (and then optimize) this module twice, once with metadata
2+
; that enables precise sqrt, and again with metadata that disables it.
3+
4+
; RUN: cat %s > %t.noprec
5+
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-prec-sqrt", i32 0}' >> %t.noprec
6+
; RUN: opt %t.noprec -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
7+
; RUN: | FileCheck %s --check-prefix=PREC_SQRT_0 --check-prefix=CHECK
8+
9+
; RUN: cat %s > %t.prec
10+
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-prec-sqrt", i32 1}' >> %t.prec
11+
; RUN: opt %t.prec -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
12+
; RUN: | FileCheck %s --check-prefix=PREC_SQRT_1 --check-prefix=CHECK
13+
14+
@.str = private unnamed_addr constant [17 x i8] c"__CUDA_PREC_SQRT\00", align 1
15+
16+
declare i32 @__nvvm_reflect(ptr)
17+
18+
; CHECK-LABEL: @foo
19+
define i32 @foo() {
20+
; CHECK-NOT: call i32 @__nvvm_reflect
21+
%reflect = call i32 @__nvvm_reflect(ptr @.str)
22+
; PREC_SQRT_0: ret i32 0
23+
; PREC_SQRT_1: ret i32 1
24+
ret i32 %reflect
25+
}
26+
27+
!llvm.module.flags = !{!0}
28+
; A module flag is added to the end of this file by the RUN lines at the top.

0 commit comments

Comments
 (0)