Skip to content

Commit b649ef0

Browse files
authored
Add fastmath flag to PTXCompilerTarget (#492)
1 parent 15f0077 commit b649ef0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/ptx.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
1818
blocks_per_sm::Union{Nothing,Int} = nothing
1919
maxregs::Union{Nothing,Int} = nothing
2020

21+
fastmath::Bool = Base.JLOptions().fast_math == 1
22+
2123
# deprecated; remove with next major version
2224
exitable::Union{Nothing,Bool} = nothing
2325
unreachable::Union{Nothing,Bool} = nothing
@@ -33,6 +35,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
3335
h = hash(target.maxthreads, h)
3436
h = hash(target.blocks_per_sm, h)
3537
h = hash(target.maxregs, h)
38+
h = hash(target.fastmath, h)
3639

3740
h
3841
end
@@ -82,6 +85,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
8285
job.config.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.config.target.maxthreads)")
8386
job.config.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.config.target.blocks_per_sm)")
8487
job.config.target.maxregs !== nothing && print(io, ", maxregs=$(job.config.target.maxregs)")
88+
job.config.target.fastmath && print(io, ", fast math enabled")
8589
end
8690

8791
const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free")
@@ -424,7 +428,7 @@ function nvvm_reflect!(fun::LLVM.Function)
424428
# handle possible cases
425429
# XXX: put some of these property in the compiler job?
426430
# and/or first set the "nvvm-reflect-*" module flag like Clang does?
427-
fast_math = Base.JLOptions().fast_math == 1
431+
fast_math = current_job.config.target.fastmath
428432
# NOTE: we follow nvcc's --use_fast_math
429433
reflect_val = if reflect_arg == "__CUDA_FTZ"
430434
# single-precision denormals support
@@ -433,7 +437,7 @@ function nvvm_reflect!(fun::LLVM.Function)
433437
# single-precision floating-point division and reciprocals.
434438
ConstantInt(reflect_typ, fast_math ? 0 : 1)
435439
elseif reflect_arg == "__CUDA_PREC_SQRT"
436-
# single-precision denormals support
440+
# single-precision floating point square roots.
437441
ConstantInt(reflect_typ, fast_math ? 0 : 1)
438442
elseif reflect_arg == "__CUDA_FMAD"
439443
# contraction of floating-point multiplies and adds/subtracts into

0 commit comments

Comments
 (0)