|
1 | 1 | module LogExpFunctions
|
2 | 2 |
|
3 |
| -export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, |
4 |
| - log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax!, softmax |
5 |
| - |
6 | 3 | using DocStringExtensions: SIGNATURES
|
7 |
| - |
8 | 4 | using Base: Math.@horner, @irrational
|
9 | 5 |
|
####
#### constants
####

# Logarithms of frequently-needed constants, declared as `Irrational`s so that
# they convert exactly (correctly rounded) to any floating-point precision on
# demand instead of being fixed at Float64.
@irrational loghalf -0.6931471805599453094 log(big(0.5))
@irrational logtwo 0.6931471805599453094 log(big(2.))
@irrational logπ 1.1447298858494001741 log(big(π))
@irrational log2π 1.8378770664093454836 log(big(2.)*π)
@irrational log4π 2.5310242469692907930 log(big(4.)*π)
"""
    xlogx(x::Number)

Return `x * log(x)` for `x ≥ 0`, handling ``x = 0`` by taking the downward limit.

```jldoctest
julia> xlogx(0)
0.0
```
"""
function xlogx(x::Number)
    # Compute the product unconditionally so the zero branch has the same
    # floating-point type as the general branch.
    xlx = x * log(x)
    return iszero(x) ? zero(xlx) : xlx
end
"""
    xlogy(x::Number, y::Number)

Return `x * log(y)` for `y > 0` with correct limit at ``x = 0``.

```jldoctest
julia> xlogy(0, 0)
0.0
```
"""
function xlogy(x::Number, y::Number)
    xly = x * log(y)
    # When x == 0 the limit is 0 — unless y is NaN, which must propagate.
    return (iszero(x) && !isnan(y)) ? zero(xly) : xly
end
"""
    logistic(x::Real)

The [logistic](https://en.wikipedia.org/wiki/Logistic_function) sigmoid function mapping a
real number to a value in the interval ``[0,1]``,

```math
\\sigma(x) = \\frac{1}{e^{-x} + 1} = \\frac{e^x}{1+e^x}.
```

Its inverse is the [`logit`](@ref) function.
"""
function logistic(x::Real)
    return inv(one(x) + exp(-x))
end
# Saturation bounds for `logistic` at each standard floating-point precision:
# below the lower bound the result underflows to zero, above the upper bound it
# rounds up to one. These are precomputed versions of the following abstract
# function, because the implicit interface for AbstractFloat doesn't uniformly
# enforce that all floating point types implement nextfloat and prevfloat:
#
#     _logistic_bounds(x::AbstractFloat) =
#         (logit(nextfloat(zero(float(x)))), logit(prevfloat(one(float(x)))))
@inline _logistic_bounds(::Float16) = (Float16(-16.64), Float16(7.625))
@inline _logistic_bounds(::Float32) = (-103.27893f0, 16.635532f0)
@inline _logistic_bounds(::Float64) = (-744.4400719213812, 36.7368005696771)

function logistic(x::Union{Float16, Float32, Float64})
    lo, hi = _logistic_bounds(x)
    ex = exp(x)
    # Branch-free selection: clamp to exactly 0 or 1 outside the representable
    # range, otherwise use the e^x / (1 + e^x) form.
    return ifelse(x < lo, zero(x), ifelse(x > hi, one(x), ex / (one(x) + ex)))
end
"""
    logit(x::Real)

The [logit](https://en.wikipedia.org/wiki/Logit) or log-odds transformation,

```math
\\log\\left(\\frac{x}{1-x}\\right), \\text{where} 0 < x < 1
```

Its inverse is the [`logistic`](@ref) function.
"""
function logit(x::Real)
    return log(x / (one(x) - x))
end
"""
    log1psq(x::Real)

Return `log(1+x^2)` evaluated carefully for `abs(x)` very small or very large.
"""
log1psq(x::Real) = log1p(abs2(x))
function log1psq(x::Union{Float32,Float64})
    y = abs(x)
    # Once y^2 would lose integer precision (y ≥ maxintfloat), 1 + y^2 == y^2
    # in floating point, so 2*log(y) is the accurate (and overflow-free) form.
    return y < maxintfloat(x) ? log1p(abs2(y)) : 2 * log(y)
end
"""
    log1pexp(x::Real)

Return `log(1+exp(x))` evaluated carefully for largish `x`.

This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
"""
function log1pexp(x::Real)
    # Thresholds (Float64): below 18 the naive form is accurate; up to 33.3 the
    # first-order correction exp(-x) is still representable; beyond that the
    # result rounds to x itself.
    if x < 18.0
        return log1p(exp(x))
    elseif x < 33.3
        return x + exp(-x)
    else
        return oftype(exp(-x), x)
    end
end
function log1pexp(x::Float32)
    # Same scheme with thresholds scaled to Float32 precision.
    if x < 9.0f0
        return log1p(exp(x))
    elseif x < 16.0f0
        return x + exp(-x)
    else
        return oftype(exp(-x), x)
    end
end
"""
    log1mexp(x::Real)

Return `log(1 - exp(x))`

See:
 * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)

Note: different than Maechler (2012), no negation inside parentheses
"""
function log1mexp(x::Real)
    # Split at log(1/2): for x below it, exp(x) is small and log1p(-exp(x)) is
    # accurate; otherwise expm1 avoids catastrophic cancellation in 1 - exp(x).
    return x < loghalf ? log1p(-exp(x)) : log(-expm1(x))
end
"""
    log2mexp(x::Real)

Return `log(2 - exp(x))` evaluated as `log1p(-expm1(x))`
"""
function log2mexp(x::Real)
    # 2 - exp(x) == 1 + (1 - exp(x)) == 1 - expm1(x).
    return log1p(-expm1(x))
end
"""
    logexpm1(x::Real)

Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse of
[`log1pexp`](@ref) (aka “softplus”).
"""
function logexpm1(x::Real)
    # Thresholds mirror `log1pexp`: naive form below 18, first-order correction
    # -exp(-x) up to 33.3, and x itself beyond (result rounds to x).
    if x <= 18.0
        return log(expm1(x))
    elseif x <= 33.3
        return x - exp(-x)
    else
        return oftype(exp(-x), x)
    end
end
function logexpm1(x::Float32)
    # Same scheme with thresholds scaled to Float32 precision.
    if x <= 9.0f0
        return log(expm1(x))
    elseif x <= 16.0f0
        return x - exp(-x)
    else
        return oftype(exp(-x), x)
    end
end
160 |
# Aliases: in the neural-network literature `log1pexp` is known as the
# "softplus" activation and `logexpm1` as its inverse.
const softplus = log1pexp
const invsoftplus = logexpm1
"""
$(SIGNATURES)

Return `log(1 + x) - x`.

Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`.
"""
function log1pmx(x::Float64)
    # Range reduction: write 1 + x = c*(1 + u), so that
    #   log1pmx(x) = log1pmx(u) + log(c) + u - x,
    # mapping x into the kernel's accurate interval (-0.227, 0.315).
    # Each e-notation constant below is the precomputed value of (offset - log(c)).
    if !(-0.7 < x < 0.9)
        # Outside (-0.7, 0.9) the naive form is already accurate.
        return log1p(x) - x
    elseif x > 0.315
        u = (x-0.5)/1.5
        # 9.45348918918356180e-2 == 0.5 - log(1.5)
        return _log1pmx_ker(u) - 9.45348918918356180e-2 - 0.5*u
    elseif x > -0.227
        # Already inside the kernel range; no reduction needed.
        return _log1pmx_ker(x)
    elseif x > -0.4
        u = (x+0.25)/0.75
        # 3.76820724517809274e-2 == -0.25 - log(0.75)
        return _log1pmx_ker(u) - 3.76820724517809274e-2 + 0.25*u
    elseif x > -0.6
        u = (x+0.5)*2.0
        # 1.93147180559945309e-1 == -0.5 - log(0.5)
        return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u
    else
        u = (x+0.625)/0.375
        # 3.55829253011726237e-1 == -0.625 - log(0.375)
        return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u
    end
end
"""
$(SIGNATURES)

Return `log(x) - x + 1` carefully evaluated.
"""
function logmxp1(x::Float64)
    # logmxp1(x) == log1pmx(x - 1). For x near 1 the subtraction x - 1 loses no
    # precision, so delegate; for x near the kernel's scaled intervals, reduce
    # directly onto `_log1pmx_ker` (constants match the ones in `log1pmx`).
    if x <= 0.3
        # Naive form is accurate for small x.
        return (log(x) + 1.0) - x
    elseif x <= 0.4
        u = (x-0.375)/0.375
        # 3.55829253011726237e-1 == -0.625 - log(0.375)
        return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u
    elseif x <= 0.6
        u = 2.0*(x-0.5)
        # 1.93147180559945309e-1 == -0.5 - log(0.5)
        return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u
    else
        return log1pmx(x - 1.0)
    end
end
# Kernel of `log1pmx`: computes log(1 + x) - x, accurate within ~2 ulps for
# -0.227 < x < 0.315.
#
# Uses the substitution r = x/(x + 2), t = r^2, under which
#   log(1 + x) = log((1 + r)/(1 - r)) = 2r*(1 + t/3 + t^2/5 + ...),
# and the tail series w = 2/3 + 2t/5 + 2t^2/7 + ... is evaluated by Horner's
# rule; the final line folds in the leading terms without cancellation.
function _log1pmx_ker(x::Float64)
    r = x / (x + 2.0)
    t = r * r
    w = @horner(t,
        6.66666666666666667e-1, # 2/3
        4.00000000000000000e-1, # 2/5
        2.85714285714285714e-1, # 2/7
        2.22222222222222222e-1, # 2/9
        1.81818181818181818e-1, # 2/11
        1.53846153846153846e-1, # 2/13
        1.33333333333333333e-1, # 2/15
        1.17647058823529412e-1) # 2/17
    hxsq = 0.5 * x * x
    return r * (hxsq + w * t) - hxsq
end
227 |
"""
    logaddexp(x::Real, y::Real)

Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/underflow, and handling
non-finite values.
"""
function logaddexp(x::Real, y::Real)
    # abs(x - y) would be NaN when x == y == ±Inf; force Δ = 0 in that case so
    # the result is x (== y) rather than NaN.
    Δ = ifelse(x == y, zero(x - y), abs(x - y))
    return max(x, y) + log1pexp(-Δ)
end

# The two-scalar form of `logsumexp` was renamed `logaddexp`.
Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x, y)
241 |
"""
    logsubexp(x::Real, y::Real)

Return `log(abs(exp(x) - exp(y)))`, preserving numerical accuracy.
"""
function logsubexp(x::Real, y::Real)
    # Factor out the larger exponent: log|e^x - e^y| = max(x,y) + log(1 - e^{-|x-y|}).
    return max(x, y) + log1mexp(-abs(x - y))
end
248 |
"""
    softmax!(r::AbstractArray, x::AbstractArray)

Overwrite `r` with the `softmax` (or _normalized exponential_) transformation of `x`

That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.

See the [Wikipedia entry](https://en.wikipedia.org/wiki/Softmax_function)
"""
function softmax!(r::AbstractArray{R}, x::AbstractArray{T}) where {R<:AbstractFloat,T<:Real}
    length(r) == length(x) || throw(DimensionMismatch("Inconsistent array lengths."))
    # Subtract the maximum before exponentiating so exp cannot overflow; this
    # leaves the result unchanged mathematically.
    u = maximum(x)
    # Accumulate in the output element type R. (Previously the accumulator was
    # the literal `0.`, i.e. always Float64, which was type-unstable and lossy
    # for non-Float64 outputs.)
    s = zero(R)
    # `eachindex(r, x)` also guards against mismatched (e.g. offset) axes,
    # which `1:n` with `@inbounds` would silently mishandle.
    @inbounds for i in eachindex(r, x)
        r[i] = exp(x[i] - u)
        s += r[i]
    end
    invs = inv(s)
    @inbounds for i in eachindex(r)
        r[i] *= invs
    end
    return r
end
272 |
"""
    softmax!(x::AbstractArray{<:AbstractFloat})

Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
applied to `x` *in place*.
"""
function softmax!(x::AbstractArray{<:AbstractFloat})
    # Reuse the two-argument method with `x` as both destination and source.
    return softmax!(x, x)
end
280 |
"""
    softmax(x::AbstractArray{<:Real})

Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
applied to `x`.
"""
function softmax(x::AbstractArray{<:Real})
    # Allocate a fresh Float64 result and delegate to the in-place method.
    return softmax!(similar(x, Float64), x)
end
| 6 | +export loghalf, logtwo, logπ, log2π, log4π |
| 7 | +export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, |
| 8 | + softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax, |
| 9 | + softmax! |
287 | 10 |
|
| 11 | +include("constants.jl") |
| 12 | +include("basicfuns.jl") |
288 | 13 | include("logsumexp.jl")
|
289 | 14 |
|
290 | 15 | end # module
|
0 commit comments