Skip to content

Commit 35f18e3

Browse files
Add "a" parameter to softplus()
1 parent 289114f commit 35f18e3

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

src/basicfuns.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,14 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
165165
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
166166
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
167167
168+
The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that control
169+
the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is
170+
equivalent to `log1pexp`.
171+
168172
See:
169173
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
170-
"""
174+
* Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180.
175+
"""
171176
log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Float64 etc. dispatch to the same algorithm
172177

173178
# Approximations based on Maechler (2012)
@@ -255,10 +260,22 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o
255260
[`log1pexp`](@ref) (aka “softplus”).
256261
"""
257262
logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
258-
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)
263+
logexpm1(x::Float32) = x <= 9.0f0 ? log(expm1(x)) : x <= 16.0f0 ? x - exp(-x) : oftype(exp(-x), x)
264+
265+
function softplus(x; a::Real=1.0)
266+
if a == 1.0
267+
return log1pexp(x)
268+
end
269+
return log1pexp(a * x) / a
270+
end
271+
272+
function invsoftplus(y; a::Real=1.0)
273+
if a == 1.0
274+
return logexpm1(y)
275+
end
276+
return logexpm1(a * y) / a
277+
end
259278

260-
const softplus = log1pexp
261-
const invsoftplus = logexpm1
262279

263280
"""
264281
$(SIGNATURES)

0 commit comments

Comments
 (0)