Skip to content

Commit 4e50c54

Browse files
cossiodevmotion
andauthored
log1pexp (#37)
* log1pexp(x) for x < -37 Based on https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf * generic log1pexp * simpler * generic threshold * support Julia 1.0 expm1(::Float16) was not defined in Julia 1.0 So it's better (and more accurate) to compute the threshold in BigFloat, and convert to the appropriate float type in the end. Since this is generated it doesn't cost in terms of performance. * special thresholds * add exp branch * simplify thresholds * more comments * more tests * inline * generic x2 * comment * unnecessary broadcast * oftype -> float since x0 is of type float(x) anyway * comment * < instead of <= * x = float(_x) * hard-coded bounds for FLoat32, Float64 * typo * comment * typo * comment * compiler is smart enough we don't need generated thresholds! * special case log1pexp(x::BigFloat) dynamic thresholds for log1pexp(x::BigFloat) are slow, so use slower but accurate implementation in this case * rewrite comment * don't need float(x) * rewrite _log1pexp_thresholds (more readable?) * Float16 and typo * test at +/- 1 * comment tests * typo Co-authored-by: David Widmann <[email protected]> * one-line Co-authored-by: David Widmann <[email protected]> * better comments * test log1pexp with multiple precisions * bump version * hard-code same thresholds as given by generic fallback * Final fixes Co-authored-by: David Widmann <[email protected]>
1 parent 8ce6807 commit 4e50c54

File tree

4 files changed

+89
-20
lines changed

4 files changed

+89
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/basicfuns.jl

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,59 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
152152
153153
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
154154
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
155+
156+
See:
157+
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
155158
"""
156-
log1pexp(x::Real) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x)
157-
log1pexp(x::Float32) = x < 9.0f0 ? log1p(exp(x)) : x < 16.0f0 ? x + exp(-x) : oftype(exp(-x), x)
159+
log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Float64 etc. dispatch to the same algorithm
160+
161+
# Approximations based on Maechler (2012)
162+
# Argument `x` is a floating point number due to the definition of `log1pexp` above
163+
function _log1pexp(x::Real)
164+
x0, x1, x2 = _log1pexp_thresholds(x)
165+
if x < x0
166+
return exp(x)
167+
elseif x < x1
168+
return log1p(exp(x))
169+
elseif x < x2
170+
return x + exp(-x)
171+
else
172+
return x
173+
end
174+
end
175+
176+
#= The precision of BigFloat cannot be computed from the type only and computing
177+
thresholds is slow. Therefore prefer version without thresholds in this case. =#
178+
_log1pexp(x::BigFloat) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
179+
180+
#=
181+
Returns thresholds x0, x1, x2 such that:
182+
183+
* log1pexp(x) ≈ exp(x) for x ≤ x0
184+
* log1pexp(x) ≈ log1p(exp(x)) for x0 < x ≤ x1
185+
* log1pexp(x) ≈ x + exp(-x) for x1 < x ≤ x2
186+
* log1pexp(x) ≈ x for x > x2
187+
188+
where the tolerances of the approximations are on the order of eps(typeof(x)).
189+
For types for which `precision(x)` depends only on the type of `x`, the compiler
190+
should optimize away all computations done here.
191+
=#
192+
@inline function _log1pexp_thresholds(x::Real)
193+
prec = precision(x)
194+
logtwo = oftype(x, IrrationalConstants.logtwo)
195+
x0 = -prec * logtwo
196+
x1 = (prec - 1) * logtwo / 2
197+
x2 = -x0 - log(-x0) * (1 + 1 / x0) # approximate root of e^-x == x * ϵ/2 via asymptotics of Lambert's W function
198+
return (x0, x1, x2)
199+
end
200+
201+
#=
202+
For common types we hard-code the thresholds to make absolutely sure they are not recomputed
203+
each time. Also, _log1pexp_thresholds is not elided by the compiler in Julia 1.0 / 1.6.
204+
=#
205+
@inline _log1pexp_thresholds(::Float64) = (-36.7368005696771, 18.021826694558577, 33.23111882352963)
206+
@inline _log1pexp_thresholds(::Float32) = (-16.635532f0, 7.9711924f0, 13.993f0)
207+
@inline _log1pexp_thresholds(::Float16) = (Float16(-7.625), Float16(3.467), Float16(5.86))
158208

159209
"""
160210
$(SIGNATURES)

test/basicfuns.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,37 @@ end
110110
# log1pexp, log1mexp, log2mexp & logexpm1
111111

112112
@testset "log1pexp" begin
113-
@test log1pexp(2.0) log(1.0 + exp(2.0))
114-
@test log1pexp(-2.0) log(1.0 + exp(-2.0))
115-
@test log1pexp(10000) 10000.0
116-
@test log1pexp(-10000) 0.0
117-
118-
@test log1pexp(2f0) log(1f0 + exp(2f0))
119-
@test log1pexp(-2f0) log(1f0 + exp(-2f0))
120-
@test log1pexp(10000f0) 10000f0
121-
@test log1pexp(-10000f0) 0f0
113+
for T in (Float16, Float32, Float64, BigFloat), x in 1:40
114+
@test (@inferred log1pexp(+log(T(x)))) T(log1p(big(x)))
115+
@test (@inferred log1pexp(-log(T(x)))) T(log1p(1/big(x)))
116+
end
117+
118+
# special values
119+
@test (@inferred log1pexp(0)) log(2)
120+
@test (@inferred log1pexp(0f0)) log(2)
121+
@test (@inferred log1pexp(big(0))) log(2)
122+
@test (@inferred log1pexp(+1)) log1p(ℯ)
123+
@test (@inferred log1pexp(-1)) log1p(ℯ) - 1
124+
125+
# large arguments
126+
@test (@inferred log1pexp(1e4)) 1e4
127+
@test (@inferred log1pexp(1f4)) 1f4
128+
@test iszero(@inferred log1pexp(-1e4))
129+
@test iszero(@inferred log1pexp(-1f4))
130+
131+
# compare to accurate but slower implementation
132+
correct_log1pexp(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
133+
# large range needed to cover all branches, for all floats (from Float16 to BigFloat)
134+
for T in (Int, Float16, Float32, Float64, BigInt, BigFloat), x in -300:300
135+
@test (@inferred log1pexp(T(x))) float(T)(correct_log1pexp(big(x)))
136+
end
137+
# test BigFloat with multiple precisions
138+
for prec in (10, 20, 50, 100), x in -300:300
139+
setprecision(prec) do
140+
y = big(float(x))
141+
@test @inferred(log1pexp(y)) correct_log1pexp(y)
142+
end
143+
end
122144
end
123145

124146
@testset "log1mexp" begin

test/chainrules.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,11 @@
5757
test_rrule(logcosh, x)
5858
end
5959

60-
# test all branches of `log1pexp`
61-
for x in (-20.9, 15.4, 41.5)
62-
test_frule(log1pexp, x)
63-
test_rrule(log1pexp, x)
64-
end
65-
for x in (8.3f0, 12.5f0, 21.2f0)
66-
test_frule(log1pexp, x; rtol=1f-3, atol=1f-3)
67-
test_rrule(log1pexp, x; rtol=1f-3, atol=1f-3)
60+
@testset "log1pexp" begin
61+
for absx in (0, 1, 2, 10, 15, 20, 40), x in (-absx, absx)
62+
test_scalar(log1pexp, Float64(x))
63+
test_scalar(log1pexp, Float32(x); rtol=1f-3, atol=1f-3)
64+
end
6865
end
6966

7067
for x in (-10.2, -3.3, -0.3)

0 commit comments

Comments
 (0)