Skip to content

Commit 6709eda

Browse files
authored
Add threshold for zero results in log1pexp (#43)
* Add threshold for zero results in `log1pexp` * Fix tests on Julia 1.0
1 parent cbf9441 commit 6709eda

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.10"
4+
version = "0.3.11"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/basicfuns.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,14 @@ log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Floa
161161
# Approximations based on Maechler (2012)
162162
# Argument `x` is a floating point number due to the definition of `log1pexp` above
163163
function _log1pexp(x::Real)
164-
x0, x1, x2 = _log1pexp_thresholds(x)
165-
if x < x0
164+
x1, x2, x3, x4 = _log1pexp_thresholds(x)
165+
if x < x1
166+
return zero(x)
167+
elseif x < x2
166168
return exp(x)
167-
elseif x < x1
169+
elseif x < x3
168170
return log1p(exp(x))
169-
elseif x < x2
171+
elseif x < x4
170172
return x + exp(-x)
171173
else
172174
return x
@@ -178,12 +180,12 @@ thresholds is slow. Therefore prefer version without thresholds in this case. =#
178180
_log1pexp(x::BigFloat) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
179181

180182
#=
181-
Returns thresholds x0, x1, x2 such that:
182-
183-
* log1pexp(x) ≈ exp(x) for x ≤ x0
184-
* log1pexp(x) ≈ log1p(exp(x)) for x0 < x ≤ x1
185-
* log1pexp(x) ≈ x + exp(-x) for x1 < x ≤ x2
186-
* log1pexp(x) ≈ x for x > x2
183+
Returns thresholds x1, x2, x3, x4 such that:
184+
* log1pexp(x) = 0 for x < x1
185+
* log1pexp(x) ≈ exp(x) for x < x2
186+
* log1pexp(x) ≈ log1p(exp(x)) for x2 ≤ x < x3
187+
* log1pexp(x) ≈ x + exp(-x) for x3 ≤ x < x4
188+
* log1pexp(x) ≈ x for x ≥ x4
187189
188190
where the tolerances of the approximations are on the order of eps(typeof(x)).
189191
For types for which `precision(x)` depends only on the type of `x`, the compiler
@@ -192,19 +194,20 @@ should optimize away all computations done here.
192194
@inline function _log1pexp_thresholds(x::Real)
193195
prec = precision(x)
194196
logtwo = oftype(x, IrrationalConstants.logtwo)
195-
x0 = -prec * logtwo
196-
x1 = (prec - 1) * logtwo / 2
197-
x2 = -x0 - log(-x0) * (1 + 1 / x0) # approximate root of e^-x == x * ϵ/2 via asymptotics of Lambert's W function
198-
return (x0, x1, x2)
197+
x1 = (exponent(nextfloat(zero(x))) - 1) * logtwo
198+
x2 = -prec * logtwo
199+
x3 = (prec - 1) * logtwo / 2
200+
x4 = -x2 - log(-x2) * (1 + 1 / x2) # approximate root of e^-x == x * ϵ/2 via asymptotics of Lambert's W function
201+
return (x1, x2, x3, x4)
199202
end
200203

201204
#=
202205
For common types we hard-code the thresholds to make absolutely sure they are not recomputed
203206
each time. Also, _log1pexp_thresholds is not elided by the compiler in Julia 1.0 / 1.6.
204207
=#
205-
@inline _log1pexp_thresholds(::Float64) = (-36.7368005696771, 18.021826694558577, 33.23111882352963)
206-
@inline _log1pexp_thresholds(::Float32) = (-16.635532f0, 7.9711924f0, 13.993f0)
207-
@inline _log1pexp_thresholds(::Float16) = (Float16(-7.625), Float16(3.467), Float16(5.86))
208+
@inline _log1pexp_thresholds(::Float64) = (-745.1332191019412, -36.7368005696771, 18.021826694558577, 33.23111882352963)
209+
@inline _log1pexp_thresholds(::Float32) = (-103.97208f0, -16.635532f0, 7.9711924f0, 13.993f0)
210+
@inline _log1pexp_thresholds(::Float16) = (Float16(-17.33), Float16(-7.625), Float16(3.467), Float16(5.86))
208211

209212
"""
210213
$(SIGNATURES)

test/basicfuns.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ end
128128
@test iszero(@inferred log1pexp(-1e4))
129129
@test iszero(@inferred log1pexp(-1f4))
130130

131+
# (almost) zero results
132+
for T in (Float16, Float32, Float64), x in (log(nextfloat(zero(T))), log(nextfloat(zero(T))) - 1)
133+
@test @inferred(log1pexp(x)) === log1p(exp(x))
134+
end
135+
136+
# hard-coded thresholds
137+
for T in (Float16, Float32, Float64)
138+
@test LogExpFunctions._log1pexp_thresholds(zero(T)) === invoke(LogExpFunctions._log1pexp_thresholds, Tuple{Real}, zero(T))
139+
end
140+
131141
# compare to accurate but slower implementation
132142
correct_log1pexp(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
133143
# large range needed to cover all branches, for all floats (from Float16 to BigFloat)

0 commit comments

Comments
 (0)