Skip to content

Commit 30b8bbf

Browse files
authored
Merge pull request #11 from JuliaStats/dw/statsfuns
2 parents e8c686c + 69e0d56 commit 30b8bbf

File tree

6 files changed

+434
-445
lines changed

6 files changed

+434
-445
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.2"
4+
version = "0.2.1"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/LogExpFunctions.jl

Lines changed: 6 additions & 281 deletions
Original file line numberDiff line numberDiff line change
@@ -1,290 +1,15 @@
11
module LogExpFunctions
22

3-
export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
4-
log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax!, softmax
5-
63
using DocStringExtensions: SIGNATURES
7-
84
using Base: Math.@horner, @irrational
95

10-
####
11-
#### constants
12-
####
13-
14-
@irrational loghalf -0.6931471805599453094 log(big(0.5))
15-
@irrational logtwo 0.6931471805599453094 log(big(2.))
16-
@irrational logπ 1.1447298858494001741 log(big(π))
17-
@irrational log2π 1.8378770664093454836 log(big(2.)*π)
18-
@irrational log4π 2.5310242469692907930 log(big(4.)*π)
19-
20-
####
21-
#### functions
22-
####
23-
24-
"""
25-
$(SIGNATURES)
26-
27-
Return `x * log(x)` for `x ≥ 0`, handling ``x = 0`` by taking the downward limit.
28-
29-
```jldoctest
30-
julia> xlogx(0)
31-
0.0
32-
```
33-
"""
34-
function xlogx(x::Number)
35-
result = x * log(x)
36-
ifelse(iszero(x), zero(result), result)
37-
end
38-
39-
"""
40-
$(SIGNATURES)
41-
42-
Return `x * log(y)` for `y > 0` with correct limit at ``x = 0``.
43-
44-
```jldoctest
45-
julia> xlogy(0, 0)
46-
0.0
47-
```
48-
"""
49-
function xlogy(x::Number, y::Number)
50-
result = x * log(y)
51-
ifelse(iszero(x) && !isnan(y), zero(result), result)
52-
end
53-
54-
"""
55-
$(SIGNATURES)
56-
57-
The [logistic](https://en.wikipedia.org/wiki/Logistic_function) sigmoid function mapping a
58-
real number to a value in the interval ``[0,1]``,
59-
60-
```math
61-
\\sigma(x) = \\frac{1}{e^{-x} + 1} = \\frac{e^x}{1+e^x}.
62-
```
63-
64-
Its inverse is the [`logit`](@ref) function.
65-
"""
66-
logistic(x::Real) = inv(exp(-x) + one(x))
67-
68-
# The following bounds are precomputed versions of the following abstract
69-
# function, but the implicit interface for AbstractFloat doesn't uniformly
70-
# enforce that all floating point types implement nextfloat and prevfloat.
71-
# @inline function _logistic_bounds(x::AbstractFloat)
72-
# (
73-
# logit(nextfloat(zero(float(x)))),
74-
# logit(prevfloat(one(float(x)))),
75-
# )
76-
# end
77-
78-
@inline _logistic_bounds(::Float16) = (Float16(-16.64), Float16(7.625))
79-
@inline _logistic_bounds(::Float32) = (-103.27893f0, 16.635532f0)
80-
@inline _logistic_bounds(::Float64) = (-744.4400719213812, 36.7368005696771)
81-
82-
function logistic(x::Union{Float16, Float32, Float64})
83-
e = exp(x)
84-
lower, upper = _logistic_bounds(x)
85-
ifelse(
86-
x < lower,
87-
zero(x),
88-
ifelse(
89-
x > upper,
90-
one(x),
91-
e / (one(x) + e)
92-
)
93-
)
94-
end
95-
96-
"""
97-
$(SIGNATURES)
98-
99-
The [logit](https://en.wikipedia.org/wiki/Logit) or log-odds transformation,
100-
101-
```math
102-
\\log\\left(\\frac{x}{1-x}\\right), \\text{where} 0 < x < 1
103-
```
104-
105-
Its inverse is the [`logistic`](@ref) function.
106-
"""
107-
logit(x::Real) = log(x / (one(x) - x))
108-
109-
"""
110-
$(SIGNATURES)
111-
112-
Return `log(1+x^2)` evaluated carefully for `abs(x)` very small or very large.
113-
"""
114-
log1psq(x::Real) = log1p(abs2(x))
115-
function log1psq(x::Union{Float32,Float64})
116-
ax = abs(x)
117-
ax < maxintfloat(x) ? log1p(abs2(ax)) : 2 * log(ax)
118-
end
119-
120-
"""
121-
$(SIGNATURES)
122-
123-
Return `log(1+exp(x))` evaluated carefully for largish `x`.
124-
125-
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
126-
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
127-
"""
128-
log1pexp(x::Real) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x)
129-
log1pexp(x::Float32) = x < 9.0f0 ? log1p(exp(x)) : x < 16.0f0 ? x + exp(-x) : oftype(exp(-x), x)
130-
131-
"""
132-
$(SIGNATURES)
133-
134-
Return `log(1 - exp(x))`
135-
136-
See:
137-
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
138-
139-
Note: different than Maechler (2012), no negation inside parentheses
140-
"""
141-
log1mexp(x::Real) = x < loghalf ? log1p(-exp(x)) : log(-expm1(x))
142-
143-
"""
144-
$(SIGNATURES)
145-
146-
Return `log(2 - exp(x))` evaluated as `log1p(-expm1(x))`
147-
"""
148-
log2mexp(x::Real) = log1p(-expm1(x))
149-
150-
"""
151-
$(SIGNATURES)
152-
153-
Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse of
154-
[`log1pexp`](@ref) (aka “softplus”).
155-
"""
156-
logexpm1(x::Real) = x <= 18.0 ? log(expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
157-
158-
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)
159-
160-
const softplus = log1pexp
161-
const invsoftplus = logexpm1
162-
163-
"""
164-
$(SIGNATURES)
165-
166-
Return `log(1 + x) - x`.
167-
168-
Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`.
169-
"""
170-
function log1pmx(x::Float64)
171-
if !(-0.7 < x < 0.9)
172-
return log1p(x) - x
173-
elseif x > 0.315
174-
u = (x-0.5)/1.5
175-
return _log1pmx_ker(u) - 9.45348918918356180e-2 - 0.5*u
176-
elseif x > -0.227
177-
return _log1pmx_ker(x)
178-
elseif x > -0.4
179-
u = (x+0.25)/0.75
180-
return _log1pmx_ker(u) - 3.76820724517809274e-2 + 0.25*u
181-
elseif x > -0.6
182-
u = (x+0.5)*2.0
183-
return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u
184-
else
185-
u = (x+0.625)/0.375
186-
return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u
187-
end
188-
end
189-
190-
"""
191-
$(SIGNATURES)
192-
193-
Return `log(x) - x + 1` carefully evaluated.
194-
"""
195-
function logmxp1(x::Float64)
196-
if x <= 0.3
197-
return (log(x) + 1.0) - x
198-
elseif x <= 0.4
199-
u = (x-0.375)/0.375
200-
return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u
201-
elseif x <= 0.6
202-
u = 2.0*(x-0.5)
203-
return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u
204-
else
205-
return log1pmx(x - 1.0)
206-
end
207-
end
208-
209-
# The kernel of log1pmx
210-
# Accuracy within ~2ulps for -0.227 < x < 0.315
211-
function _log1pmx_ker(x::Float64)
212-
r = x/(x+2.0)
213-
t = r*r
214-
w = @horner(t,
215-
6.66666666666666667e-1, # 2/3
216-
4.00000000000000000e-1, # 2/5
217-
2.85714285714285714e-1, # 2/7
218-
2.22222222222222222e-1, # 2/9
219-
1.81818181818181818e-1, # 2/11
220-
1.53846153846153846e-1, # 2/13
221-
1.33333333333333333e-1, # 2/15
222-
1.17647058823529412e-1) # 2/17
223-
hxsq = 0.5*x*x
224-
r*(hxsq+w*t)-hxsq
225-
end
226-
227-
"""
228-
$(SIGNATURES)
229-
230-
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling
231-
non-finite values.
232-
"""
233-
function logaddexp(x::Real, y::Real)
234-
# ensure Δ = 0 if x = y = ± Inf
235-
Δ = ifelse(x == y, zero(x - y), abs(x - y))
236-
max(x, y) + log1pexp(-Δ)
237-
end
238-
239-
Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x, y)
240-
241-
"""
242-
$(SIGNATURES)
243-
244-
Return `log(abs(exp(x) - exp(y)))`, preserving numerical accuracy.
245-
"""
246-
logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
247-
248-
"""
249-
$(SIGNATURES)
250-
251-
Overwrite `r` with the `softmax` (or _normalized exponential_) transformation of `x`
252-
253-
That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.
254-
255-
See the [Wikipedia entry](https://en.wikipedia.org/wiki/Softmax_function)
256-
"""
257-
function softmax!(r::AbstractArray{R}, x::AbstractArray{T}) where {R<:AbstractFloat,T<:Real}
258-
n = length(x)
259-
length(r) == n || throw(DimensionMismatch("Inconsistent array lengths."))
260-
u = maximum(x)
261-
s = 0.
262-
@inbounds for i = 1:n
263-
s += (r[i] = exp(x[i] - u))
264-
end
265-
invs = convert(R, inv(s))
266-
@inbounds for i = 1:n
267-
r[i] *= invs
268-
end
269-
r
270-
end
271-
272-
"""
273-
$(SIGNATURES)
274-
275-
Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
276-
applied to `x` *in place*.
277-
"""
278-
softmax!(x::AbstractArray{<:AbstractFloat}) = softmax!(x, x)
279-
280-
"""
281-
$(SIGNATURES)
282-
283-
Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
284-
applied to `x`.
285-
"""
286-
softmax(x::AbstractArray{<:Real}) = softmax!(similar(x, Float64), x)
6+
export loghalf, logtwo, logπ, log2π, log4π
7+
export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
8+
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax,
9+
softmax!
28710

11+
include("constants.jl")
12+
include("basicfuns.jl")
28813
include("logsumexp.jl")
28914

29015
end # module

0 commit comments

Comments
 (0)