Skip to content

Commit 66ddb31

Browse files
authored
Merge pull request #94 from oscardssmith/patch-2
improve `log1pmx`, add `Float32` implimentation
2 parents c9adf5a + 786c5e4 commit 66ddb31

File tree

2 files changed

+31
-34
lines changed

2 files changed

+31
-34
lines changed

src/basicfuns.jl

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -288,31 +288,18 @@ $(SIGNATURES)
288288
Return `log(1 + x) - x`.
289289
290290
Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`.
291-
This will fall back to the naive calculation for argument types different from `Float64`.
291+
This will fall back to the naive calculation for argument types different from `Float32, Float64`.
292292
"""
293-
function log1pmx(x::Float64)
294-
if !(-0.7 < x < 0.9)
293+
log1pmx(x::Real) = log1p(x) - x # Naive fallback
294+
295+
function log1pmx(x::T) where T <: Union{Float32, Float64}
296+
if !(T(-0.425) < x < T(0.4)) # accurate within 2 ULPs when log2(abs(log1p(x))) > 1.5
295297
return log1p(x) - x
296-
elseif x > 0.315
297-
u = (x-0.5)/1.5
298-
return _log1pmx_ker(u) - 9.45348918918356180e-2 - 0.5*u
299-
elseif x > -0.227
300-
return _log1pmx_ker(x)
301-
elseif x > -0.4
302-
u = (x+0.25)/0.75
303-
return _log1pmx_ker(u) - 3.76820724517809274e-2 + 0.25*u
304-
elseif x > -0.6
305-
u = (x+0.5)*2.0
306-
return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u
307298
else
308-
u = (x+0.625)/0.375
309-
return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u
299+
return _log1pmx_ker(x)
310300
end
311301
end
312302

313-
# Naive fallback
314-
log1pmx(x::Real) = log1p(x) - x
315-
316303
"""
317304
$(SIGNATURES)
318305
@@ -345,21 +332,30 @@ function logmxp1(x::Real)
345332
end
346333

347334
# The kernel of log1pmx
348-
# Accuracy within ~2ulps for -0.227 < x < 0.315
349-
function _log1pmx_ker(x::Float64)
350-
r = x/(x+2.0)
335+
# Accuracy within ~2ulps -0.227 < x < 0.315 for Float64
336+
# Accuracy <2.18ulps -0.425 < x < 0.425 for Float32
337+
# parameters foudn via Remez.jl, specifically:
338+
# g(x) = evalpoly(x, big(2)./ntuple(i->2i+1, 50))
339+
# p = T.(Tuple(ratfn_minimax(g, (1e-3, (.425/(.425+2))^2), 8, 0)[1]))
340+
function _log1pmx_ker(x::T) where T <: Union{Float32, Float64}
341+
r = x / (x+2)
351342
t = r*r
352-
w = @horner(t,
353-
6.66666666666666667e-1, # 2/3
354-
4.00000000000000000e-1, # 2/5
355-
2.85714285714285714e-1, # 2/7
356-
2.22222222222222222e-1, # 2/9
357-
1.81818181818181818e-1, # 2/11
358-
1.53846153846153846e-1, # 2/13
359-
1.33333333333333333e-1, # 2/15
360-
1.17647058823529412e-1) # 2/17
361-
hxsq = 0.5*x*x
362-
r*(hxsq+w*t)-hxsq
343+
if T == Float32
344+
p = (0.6666658f0, 0.40008822f0, 0.2827692f0, 0.26246136f0)
345+
else
346+
p = (0.6666666666666669,
347+
0.3999999999997768,
348+
0.2857142857784595,
349+
0.2222222142048249,
350+
0.18181870670924566,
351+
0.15382646727504887,
352+
0.1337701340211177,
353+
0.11201972567415432,
354+
0.143418239946679)
355+
end
356+
w = evalpoly(t, p)
357+
hxsq = x*x/2
358+
muladd(r, muladd(w, t, hxsq), -hxsq)
363359
end
364360

365361

test/basicfuns.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ end
216216
@test log1pmx(2f0) log(3f0) - 2f0
217217

218218
for x in -0.5:0.1:10
219-
@test log1pmx(Float32(x)) Float32(log1pmx(x))
219+
@test log1pmx(Float32(x)) Float32(log1pmx(x)) atol=3*eps(Float32(x))
220+
@test log1pmx(x) Float64(log1pmx(big(x))) atol=3*eps(x)
220221
end
221222
end
222223

0 commit comments

Comments
 (0)