@@ -18,6 +18,8 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
18
18
blocks_per_sm:: Union{Nothing,Int} = nothing
19
19
maxregs:: Union{Nothing,Int} = nothing
20
20
21
+ fastmath:: Bool = Base. JLOptions (). fast_math == 1
22
+
21
23
# deprecated; remove with next major version
22
24
exitable:: Union{Nothing,Bool} = nothing
23
25
unreachable:: Union{Nothing,Bool} = nothing
@@ -33,6 +35,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
33
35
h = hash (target. maxthreads, h)
34
36
h = hash (target. blocks_per_sm, h)
35
37
h = hash (target. maxregs, h)
38
+ h = hash (target. fastmath, h)
36
39
37
40
h
38
41
end
@@ -82,6 +85,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
82
85
job. config. target. maxthreads != = nothing && print (io, " , maxthreads=$(job. config. target. maxthreads) " )
83
86
job. config. target. blocks_per_sm != = nothing && print (io, " , blocks_per_sm=$(job. config. target. blocks_per_sm) " )
84
87
job. config. target. maxregs != = nothing && print (io, " , maxregs=$(job. config. target. maxregs) " )
88
+ job. config. target. fastmath && print (io, " , fast math enabled" )
85
89
end
86
90
87
91
const ptx_intrinsics = (" vprintf" , " __assertfail" , " malloc" , " free" )
@@ -424,7 +428,7 @@ function nvvm_reflect!(fun::LLVM.Function)
424
428
# handle possible cases
425
429
# XXX : put some of these property in the compiler job?
426
430
# and/or first set the "nvvm-reflect-*" module flag like Clang does?
427
- fast_math = Base . JLOptions () . fast_math == 1
431
+ fast_math = current_job . config . target . fastmath
428
432
# NOTE: we follow nvcc's --use_fast_math
429
433
reflect_val = if reflect_arg == " __CUDA_FTZ"
430
434
# single-precision denormals support
@@ -433,7 +437,7 @@ function nvvm_reflect!(fun::LLVM.Function)
433
437
# single-precision floating-point division and reciprocals.
434
438
ConstantInt (reflect_typ, fast_math ? 0 : 1 )
435
439
elseif reflect_arg == " __CUDA_PREC_SQRT"
436
- # single-precision denormals support
440
+ # single-precision floating point square roots.
437
441
ConstantInt (reflect_typ, fast_math ? 0 : 1 )
438
442
elseif reflect_arg == " __CUDA_FMAD"
439
443
# contraction of floating-point multiplies and adds/subtracts into
0 commit comments