diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e2774d9..5cc0df2 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -288,31 +288,18 @@ $(SIGNATURES) Return `log(1 + x) - x`. Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`. -This will fall back to the naive calculation for argument types different from `Float64`. +This will fall back to the naive calculation for argument types different from `Float32, Float64`. """ -function log1pmx(x::Float64) - if !(-0.7 < x < 0.9) +log1pmx(x::Real) = log1p(x) - x # Naive fallback + +function log1pmx(x::T) where T <: Union{Float32, Float64} + if !(T(-0.425) < x < T(0.4)) # accurate within 2 ULPs when log2(abs(log1p(x))) > 1.5 return log1p(x) - x - elseif x > 0.315 - u = (x-0.5)/1.5 - return _log1pmx_ker(u) - 9.45348918918356180e-2 - 0.5*u - elseif x > -0.227 - return _log1pmx_ker(x) - elseif x > -0.4 - u = (x+0.25)/0.75 - return _log1pmx_ker(u) - 3.76820724517809274e-2 + 0.25*u - elseif x > -0.6 - u = (x+0.5)*2.0 - return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u else - u = (x+0.625)/0.375 - return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u + return _log1pmx_ker(x) end end -# Naive fallback -log1pmx(x::Real) = log1p(x) - x - """ $(SIGNATURES) @@ -345,21 +332,30 @@ function logmxp1(x::Real) end # The kernel of log1pmx -# Accuracy within ~2ulps for -0.227 < x < 0.315 -function _log1pmx_ker(x::Float64) - r = x/(x+2.0) +# Accuracy within ~2ulps -0.227 < x < 0.315 for Float64 +# Accuracy <2.18ulps -0.425 < x < 0.425 for Float32 +# parameters foudn via Remez.jl, specifically: +# g(x) = evalpoly(x, big(2)./ntuple(i->2i+1, 50)) +# p = T.(Tuple(ratfn_minimax(g, (1e-3, (.425/(.425+2))^2), 8, 0)[1])) +function _log1pmx_ker(x::T) where T <: Union{Float32, Float64} + r = x / (x+2) t = r*r - w = @horner(t, - 6.66666666666666667e-1, # 2/3 - 4.00000000000000000e-1, # 2/5 - 2.85714285714285714e-1, # 2/7 - 2.22222222222222222e-1, # 2/9 - 1.81818181818181818e-1, # 2/11 - 1.53846153846153846e-1, # 2/13 - 1.33333333333333333e-1, # 2/15 - 1.17647058823529412e-1) # 2/17 - hxsq = 0.5*x*x - r*(hxsq+w*t)-hxsq + if T == Float32 + p = (0.6666658f0, 0.40008822f0, 0.2827692f0, 0.26246136f0) + else + p = (0.6666666666666669, + 0.3999999999997768, + 0.2857142857784595, + 0.2222222142048249, + 0.18181870670924566, + 0.15382646727504887, + 0.1337701340211177, + 0.11201972567415432, + 0.143418239946679) + end + w = evalpoly(t, p) + hxsq = x*x/2 + muladd(r, muladd(w, t, hxsq), -hxsq) end diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 33c6365..6a84b54 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -216,7 +216,8 @@ end @test log1pmx(2f0) ≈ log(3f0) - 2f0 for x in -0.5:0.1:10 - @test log1pmx(Float32(x)) ≈ Float32(log1pmx(x)) + @test log1pmx(Float32(x)) ≈ Float32(log1pmx(x)) atol=3*eps(Float32(x)) + @test log1pmx(x) ≈ Float64(log1pmx(big(x))) atol=3*eps(x) end end