Skip to content

Commit b30a4b9

Browse files
authored
improve log1pmx, add `Float32` implementation
Through some clever use of Remez.jl (and some testing to better limit where the fallback is appropriate), we can remove almost all of the branches from the Float64 implementation and add a similar (but slightly faster) Float32 version.
1 parent a76a741 commit b30a4b9

File tree

1 file changed

+28
-33
lines changed

1 file changed

+28
-33
lines changed

src/basicfuns.jl

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -288,31 +288,18 @@ $(SIGNATURES)
288288
Return `log(1 + x) - x`.
289289
290290
Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`.
291-
This will fall back to the naive calculation for argument types different from `Float64`.
291+
This will fall back to the naive calculation for argument types other than `Float32` and `Float64`.
292292
"""
293-
function log1pmx(x::Float64)
294-
if !(-0.7 < x < 0.9)
293+
log1pmx(x::Real) = log1p(x) - x # Naive fallback
294+
295+
function log1pmx(x::Union{Float32, Float64})
296+
if !(-0.425 < x < 0.4) # accurate within 2 ULPs when log2(abs(log1p(x))) > 1.5
295297
return log1p(x) - x
296-
elseif x > 0.315
297-
u = (x-0.5)/1.5
298-
return _log1pmx_ker(u) - 9.45348918918356180e-2 - 0.5*u
299-
elseif x > -0.227
300-
return _log1pmx_ker(x)
301-
elseif x > -0.4
302-
u = (x+0.25)/0.75
303-
return _log1pmx_ker(u) - 3.76820724517809274e-2 + 0.25*u
304-
elseif x > -0.6
305-
u = (x+0.5)*2.0
306-
return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u
307298
else
308-
u = (x+0.625)/0.375
309-
return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u
299+
return _log1pmx_ker(x)
310300
end
311301
end
312302

313-
# Naive fallback
314-
log1pmx(x::Real) = log1p(x) - x
315-
316303
"""
317304
$(SIGNATURES)
318305
@@ -345,21 +332,29 @@ function logmxp1(x::Real)
345332
end
346333

347334
# The kernel of log1pmx
348-
# Accuracy within ~2ulps for -0.227 < x < 0.315
349-
function _log1pmx_ker(x::Float64)
350-
r = x/(x+2.0)
335+
# Accuracy within ~2 ulps for -0.227 < x < 0.315 for Float64
336+
# Accuracy < 2.18 ulps for -0.425 < x < 0.425 for Float32
337+
# parameters found via Remez.jl, specifically:
338+
# g(x) = evalpoly(x, big(2)./ntuple(i->2i+1, 50))
339+
# p = T.(Tuple(ratfn_minimax(g, (1e-3, (.425/(.425+2))^2), 8, 0)[1]))
340+
function _log1pmx_ker(x::T) where T <: Union{Float32, Float64}
341+
r = x / (x+2)
351342
t = r*r
352-
w = @horner(t,
353-
6.66666666666666667e-1, # 2/3
354-
4.00000000000000000e-1, # 2/5
355-
2.85714285714285714e-1, # 2/7
356-
2.22222222222222222e-1, # 2/9
357-
1.81818181818181818e-1, # 2/11
358-
1.53846153846153846e-1, # 2/13
359-
1.33333333333333333e-1, # 2/15
360-
1.17647058823529412e-1) # 2/17
361-
hxsq = 0.5*x*x
362-
r*(hxsq+w*t)-hxsq
343+
if T == Float32
344+
p = (0.6666658f0, 0.40008822f0, 0.2827692f0, 0.26246136f0)
345+
else
346+
p = (0.6666666666666669,
347+
0.3999999999997768,
348+
0.2857142857784595,
349+
0.2222222142048249,
350+
0.18181870670924566,
351+
0.15382646727504887,
352+
0.1337701340211177,
353+
0.11201972567415432,
354+
0.143418239946679)
355+
w = evalpoly(t, p)
356+
hxsq = x*x/2
357+
muladd(r, muladd(w, t, hxsq), -hxsq)
363358
end
364359

365360

0 commit comments

Comments
 (0)