Skip to content

Commit 150fa44

Browse files
simsuracedevmotion
andauthored
Add fallbacks for log1pmx and logmxp1 (#45)
* Add fallbacks for `log1pmx` and `logmxp1` This looks like it would make sense, given that the other functions dispatch on `Real` argument, and it also fixes #44. * Use less naive heuristic Co-authored-by: David Widmann <[email protected]> * Add some tests * Bump version Co-authored-by: David Widmann <[email protected]>
1 parent ca7806c commit 150fa44

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/basicfuns.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ $(SIGNATURES)
246246
Return `log(1 + x) - x`.
247247
248248
Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`.
249+
This will fall back to the naive calculation for argument types different from `Float64`.
249250
"""
250251
function log1pmx(x::Float64)
251252
if !(-0.7 < x < 0.9)
@@ -267,10 +268,14 @@ function log1pmx(x::Float64)
267268
end
268269
end
269270

271+
# Naive fallback
272+
log1pmx(x::Real) = log1p(x) - x
273+
270274
"""
271275
$(SIGNATURES)
272276
273277
Return `log(x) - x + 1` carefully evaluated.
278+
This will fall back to the naive calculation for argument types different from `Float64`.
274279
"""
275280
function logmxp1(x::Float64)
276281
if x <= 0.3
@@ -286,6 +291,17 @@ function logmxp1(x::Float64)
286291
end
287292
end
288293

294+
# Naive fallback
295+
function logmxp1(x::Real)
296+
one_x = one(x)
297+
if 2 * x < one_x
298+
# for small values of `x` the other branch returns non-finite values
299+
return (log(x) + one_x) - x
300+
else
301+
return log1pmx(x - one_x)
302+
end
303+
end
304+
289305
# The kernel of log1pmx
290306
# Accuracy within ~2ulps for -0.227 < x < 0.315
291307
function _log1pmx_ker(x::Float64)

test/basicfuns.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,28 @@ end
177177
@test iszero(log1pmx(0.0))
178178
@test log1pmx(1.0) log(2.0) - 1.0
179179
@test log1pmx(2.0) log(3.0) - 2.0
180+
181+
@test iszero(log1pmx(0f0))
182+
@test log1pmx(1f0) log(2f0) - 1f0
183+
@test log1pmx(2f0) log(3f0) - 2f0
184+
185+
for x in -0.5:0.1:10
186+
@test log1pmx(Float32(x)) Float32(log1pmx(x))
187+
end
180188
end
181189

182190
@testset "logmxp1" begin
183191
@test iszero(logmxp1(1.0))
184192
@test logmxp1(2.0) log(2.0) - 1.0
185193
@test logmxp1(3.0) log(3.0) - 2.0
194+
195+
@test iszero(logmxp1(1f0))
196+
@test logmxp1(2f0) log(2f0) - 1f0
197+
@test logmxp1(3f0) log(3f0) - 2f0
198+
199+
for x in 0.1:0.1:10
200+
@test logmxp1(Float32(x)) Float32(logmxp1(x))
201+
end
186202
end
187203

188204
@testset "logsumexp" begin

0 commit comments

Comments
 (0)