Skip to content

Commit f75b714

Browse files
Merge pull request #85 from JuliaMath/os/improve-NaNMath-perf
use Base function rather than Libm ones since they're faster
2 parents 1c53a33 + fce37ba commit f75b714

File tree

1 file changed

+43
-9
lines changed

1 file changed

+43
-9
lines changed

src/NaNMath.jl

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,54 @@ module NaNMath
33
using OpenLibm_jll
44
const libm = OpenLibm_jll.libopenlibm
55

6-
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
7-
:lgamma, :log1p)
6+
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh,
7+
:log, :log2, :log10, :log1p, :lgamma)
88
@eval begin
9-
Base.@assume_effects :total ($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x)
10-
Base.@assume_effects :total ($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x)
11-
($f)(x::Real) = ($f)(float(x))
9+
function ($f)(x::Real)
10+
xf = float(x)
11+
x === xf && throw(MethodError($f, (x,)))
12+
($f)(xf)
13+
end
1214
if $f !== :lgamma
1315
($f)(x) = (Base.$f)(x)
1416
end
1517
end
1618
end
1719

20+
Base.@assume_effects :total lgamma(x::Float64) = ccall(("lgamma",libm), Float64, (Float64,), x)
21+
Base.@assume_effects :total lgamma(x::Float32) = ccall(("lgammaf",libm), Float32, (Float32,), x)
22+
23+
for f in (:sin, :cos, :tan)
24+
@eval begin
25+
function ($f)(x::T) where T<:Union{Float16, Float32, Float64}
26+
isinf(x) ? T(NaN) : (Base.$f)(x)
27+
end
28+
end
29+
end
30+
31+
for f in (:asin, :acos, :atanh)
32+
@eval begin
33+
function ($f)(x::T) where T<:Union{Float16, Float32, Float64}
34+
abs(x) > T(1) ? T(NaN) : (Base.$f)(x)
35+
end
36+
end
37+
end
38+
function acosh(x::T) where T<:Union{Float16, Float32, Float64}
39+
x < T(1) ? T(NaN) : acosh(x)
40+
end
41+
42+
for f in (:log, :log2, :log10)
43+
@eval begin
44+
function ($f)(x::T) where T<:Union{Float16, Float32, Float64}
45+
x < 0 ? T(NaN) : (Base.$f)(x)
46+
end
47+
end
48+
end
49+
50+
function log1p(x::T) where T<:Union{Float16, Float32, Float64}
51+
x < T(-1) ? T(NaN) : Base.log1p(x)
52+
end
53+
1854
for f in (:sqrt,)
1955
@eval ($f)(x) = (Base.$f)(x)
2056
end
@@ -23,10 +59,8 @@ for f in (:max, :min)
2359
@eval ($f)(x, y) = (Base.$f)(x, y)
2460
end
2561

26-
# Would be more efficient to remove the domain check in Base.sqrt(),
27-
# but this doesn't seem easy to do.
28-
Base.@assume_effects :nothrow sqrt(x::T) where {T<:Union{Float16, Float32, Float64}} = x < 0.0 ? T(NaN) : Base.sqrt(x)
29-
sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x)
62+
sqrt(x::T) where {T<:Union{Float16, Float32, Float64}} = x < T(0) ? T(NaN) : Base.Intrinsics.sqrt_llvm(x)
63+
sqrt(x::T) where {T<:AbstractFloat} = x < T(0) ? T(NaN) : Base.sqrt(x)
3064
sqrt(x::Real) = sqrt(float(x))
3165

3266
# Don't override built-in ^ operator

0 commit comments

Comments
 (0)