Skip to content

Commit 76a23a7

Browse files
authored
Merge pull request #85 from DominiqueMakowski/master
Add "a" parameter to softplus() #83
2 parents 289114f + 5f1d99d commit 76a23a7

File tree

7 files changed

+86
-2
lines changed

7 files changed

+86
-2
lines changed

docs/src/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ logcosh
1818
logabssinh
1919
log1psq
2020
log1pexp
21+
softplus
22+
invsoftplus
2123
log1mexp
2224
log2mexp
2325
logexpm1

ext/LogExpFunctionsChangesOfVariablesExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real)
88
y = log1pexp(x)
99
return y, x - y
1010
end
11+
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real)
12+
return ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x)
13+
end
14+
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(softplus),<:Real}, x::Real)
15+
y = f(x)
16+
return y, f.x * (x - y)
17+
end
1118

1219
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real)
1320
y = logexpm1(x)
1421
return y, x - y
1522
end
23+
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real)
24+
return ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x)
25+
end
26+
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real)
27+
y = f(x)
28+
return y, f.x * (x - y)
29+
end
1630

1731
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real)
1832
y = log1mexp(x)

ext/LogExpFunctionsInverseFunctionsExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,14 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic
2222
InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp
2323
InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic
2424

25+
InverseFunctions.inverse(::typeof(softplus)) = invsoftplus
26+
function InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real})
27+
Base.Fix2(invsoftplus, f.x)
28+
end
29+
30+
InverseFunctions.inverse(::typeof(invsoftplus)) = softplus
31+
function InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real})
32+
Base.Fix2(softplus, f.x)
33+
end
34+
2535
end # module

src/basicfuns.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
165165
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
166166
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
167167
168+
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
169+
transformation (in its default parametrization, see [`softplus`](@ref)), being a smooth approximation to `max(0,x)`.
170+
168171
See:
169172
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
170173
"""
@@ -257,8 +260,27 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o
257260
logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
258261
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)
259262

260-
const softplus = log1pexp
261-
const invsoftplus = logexpm1
263+
"""
264+
$(SIGNATURES)
265+
266+
The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that controls
267+
the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is
268+
equivalent to [`log1pexp`](@ref).
269+
270+
See:
271+
* Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180.
272+
"""
273+
softplus(x::Real) = log1pexp(x)
274+
softplus(x::Real, a::Real) = log1pexp(a * x) / a
275+
276+
"""
277+
$(SIGNATURES)
278+
279+
The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softplus`](@ref).
280+
"""
281+
invsoftplus(y::Real) = logexpm1(y)
282+
invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a
283+
262284

263285
"""
264286
$(SIGNATURES)

test/basicfuns.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ end
161161
end
162162
end
163163

164+
@testset "softplus" begin
165+
for T in (Int, Float64, Float32, Float16)
166+
@test @inferred(softplus(T(2))) === log1pexp(T(2))
167+
@test @inferred(softplus(T(2), 1)) isa float(T)
168+
@test @inferred(softplus(T(2), 1)) ≈ softplus(T(2))
169+
@test @inferred(softplus(T(2), 5)) ≈ softplus(5 * T(2)) / 5
170+
@test @inferred(softplus(T(2), 10)) ≈ softplus(10 * T(2)) / 10
171+
end
172+
end
173+
164174
@testset "log1mexp" begin
165175
for T in (Float64, Float32, Float16)
166176
@test @inferred(log1mexp(-T(1))) isa T
@@ -186,6 +196,16 @@ end
186196
end
187197
end
188198

199+
@testset "invsoftplus" begin
200+
for T in (Int, Float64, Float32, Float16)
201+
@test @inferred(invsoftplus(T(2))) === logexpm1(T(2))
202+
@test @inferred(invsoftplus(T(2), 1)) isa float(T)
203+
@test @inferred(invsoftplus(T(2), 1)) ≈ invsoftplus(T(2))
204+
@test @inferred(invsoftplus(T(2), 5)) ≈ invsoftplus(5 * T(2)) / 5
205+
@test @inferred(invsoftplus(T(2), 10)) ≈ invsoftplus(10 * T(2)) / 10
206+
end
207+
end
208+
189209
@testset "log1pmx" begin
190210
@test iszero(log1pmx(0.0))
191211
@test log1pmx(1.0) ≈ log(2.0) - 1.0

test/inverse.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
@testset "inverse.jl" begin
22
InverseFunctions.test_inverse(log1pexp, randn())
3+
InverseFunctions.test_inverse(softplus, randn())
4+
InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn())
5+
36
InverseFunctions.test_inverse(logexpm1, randexp())
7+
InverseFunctions.test_inverse(invsoftplus, randexp())
8+
InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp())
49

510
InverseFunctions.test_inverse(log1mexp, -randexp())
611

test/with_logabsdet_jacobian.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
@testset "with_logabsdet_jacobian" begin
22
derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2]
3+
derivative(::typeof(softplus), x) = derivative(log1pexp, x)
4+
derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x)
5+
derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x)
6+
derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x)
37

48
x = randexp()
9+
y = randexp()
510

611
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative)
712
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative)
13+
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative)
14+
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative)
15+
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative)
16+
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative)
817

918
ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative)
19+
ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative)
20+
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative)
1021

1122
ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)
1223

0 commit comments

Comments
 (0)