Skip to content

Commit 004812d

Browse files
committed
Move logsumexp implementation to separate file
Implementation unaffected, only reorganizes code in a similar way as in LogExpFunctions 0.2.0
1 parent 8f6f530 commit 004812d

File tree

3 files changed

+97
-97
lines changed

3 files changed

+97
-97
lines changed

src/LogExpFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, log
1010

1111
include("constants.jl")
1212
include("basicfuns.jl")
13+
include("logsumexp.jl")
1314

1415
end # module

src/basicfuns.jl

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -226,103 +226,6 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
226226
"""
227227
$(SIGNATURES)
228228
229-
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
230-
underflow.
231-
232-
`X` should be an iterator of real numbers. The result is computed using a single pass over
233-
the data.
234-
235-
# References
236-
237-
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
238-
"""
239-
logsumexp(X) = _logsumexp_onepass(X)
240-
241-
"""
242-
$(SIGNATURES)
243-
244-
Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
245-
intermediate over- and underflow.
246-
247-
The result is computed using a single pass over the data.
248-
249-
# References
250-
251-
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
252-
"""
253-
logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims)
254-
255-
_logsumexp(X::AbstractArray{<:Real}, ::Colon) = _logsumexp_onepass(X)
256-
function _logsumexp(X::AbstractArray{<:Real}, dims)
257-
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
258-
FT = float(eltype(X))
259-
xmax_r = reduce(_logsumexp_onepass_op, X; dims=dims, init=(FT(-Inf), zero(FT)))
260-
return @. first(xmax_r) + log1p(last(xmax_r))
261-
end
262-
263-
function _logsumexp_onepass(X)
264-
# fallback for empty collections
265-
isempty(X) && return log(sum(X))
266-
return _logsumexp_onepass_result(_logsumexp_onepass_reduce(X, Base.IteratorEltype(X)))
267-
end
268-
269-
# function barrier for reductions with single element and without initial element
270-
_logsumexp_onepass_result(x) = float(x)
271-
_logsumexp_onepass_result((xmax, r)::Tuple) = xmax + log1p(r)
272-
273-
# iterables with known element type
274-
function _logsumexp_onepass_reduce(X, ::Base.HasEltype)
275-
# do not perform type computations if element type is abstract
276-
T = eltype(X)
277-
isconcretetype(T) || return _logsumexp_onepass_reduce(X, Base.EltypeUnknown())
278-
279-
FT = float(T)
280-
return reduce(_logsumexp_onepass_op, X; init=(FT(-Inf), zero(FT)))
281-
end
282-
283-
# iterables without known element type
284-
_logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_op, X)
285-
286-
## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
287-
288-
# reduce two numbers
289-
function _logsumexp_onepass_op(x1, x2)
290-
a = x1 == x2 ? zero(x1 - x2) : -abs(x1 - x2)
291-
xmax = x1 > x2 ? oftype(a, x1) : oftype(a, x2)
292-
r = exp(a)
293-
return xmax, r
294-
end
295-
296-
# reduce a number and a partial sum
297-
function _logsumexp_onepass_op(x, (xmax, r)::Tuple)
298-
a = x == xmax ? zero(x - xmax) : -abs(x - xmax)
299-
if x > xmax
300-
_xmax = oftype(a, x)
301-
_r = (r + one(r)) * exp(a)
302-
else
303-
_xmax = oftype(a, xmax)
304-
_r = r + exp(a)
305-
end
306-
return _xmax, _r
307-
end
308-
_logsumexp_onepass_op(xmax_r::Tuple, x) = _logsumexp_onepass_op(x, xmax_r)
309-
310-
# reduce two partial sums
311-
function _logsumexp_onepass_op((xmax1, r1)::Tuple, (xmax2, r2)::Tuple)
312-
a = xmax1 == xmax2 ? zero(xmax1 - xmax2) : -abs(xmax1 - xmax2)
313-
if xmax1 > xmax2
314-
xmax = oftype(a, xmax1)
315-
r = r1 + (r2 + one(r2)) * exp(a)
316-
else
317-
xmax = oftype(a, xmax2)
318-
r = r2 + (r1 + one(r1)) * exp(a)
319-
end
320-
return xmax, r
321-
end
322-
323-
"""
324-
$(SIGNATURES)
325-
326229
Overwrite `r` with the `softmax` (or _normalized exponential_) transformation of `x`
327230
328231
That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.

src/logsumexp.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""
2+
$(SIGNATURES)
3+
4+
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
5+
underflow.
6+
7+
`X` should be an iterator of real numbers. The result is computed using a single pass over
8+
the data.
9+
10+
# References
11+
12+
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
13+
"""
14+
logsumexp(X) = _logsumexp_onepass(X)
15+
16+
"""
17+
$(SIGNATURES)
18+
19+
Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
20+
intermediate over- and underflow.
21+
22+
The result is computed using a single pass over the data.
23+
24+
# References
25+
26+
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
27+
"""
28+
logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims)
29+
30+
_logsumexp(X::AbstractArray{<:Real}, ::Colon) = _logsumexp_onepass(X)
31+
function _logsumexp(X::AbstractArray{<:Real}, dims)
32+
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
33+
FT = float(eltype(X))
34+
xmax_r = reduce(_logsumexp_onepass_op, X; dims=dims, init=(FT(-Inf), zero(FT)))
35+
return @. first(xmax_r) + log1p(last(xmax_r))
36+
end
37+
38+
function _logsumexp_onepass(X)
39+
# fallback for empty collections
40+
isempty(X) && return log(sum(X))
41+
return _logsumexp_onepass_result(_logsumexp_onepass_reduce(X, Base.IteratorEltype(X)))
42+
end
43+
44+
# function barrier for reductions with single element and without initial element
45+
_logsumexp_onepass_result(x) = float(x)
46+
_logsumexp_onepass_result((xmax, r)::Tuple) = xmax + log1p(r)
47+
48+
# iterables with known element type
49+
function _logsumexp_onepass_reduce(X, ::Base.HasEltype)
50+
# do not perform type computations if element type is abstract
51+
T = eltype(X)
52+
isconcretetype(T) || return _logsumexp_onepass_reduce(X, Base.EltypeUnknown())
53+
54+
FT = float(T)
55+
return reduce(_logsumexp_onepass_op, X; init=(FT(-Inf), zero(FT)))
56+
end
57+
58+
# iterables without known element type
59+
_logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_op, X)
60+
61+
## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
62+
63+
# reduce two numbers
64+
function _logsumexp_onepass_op(x1, x2)
65+
a = x1 == x2 ? zero(x1 - x2) : -abs(x1 - x2)
66+
xmax = x1 > x2 ? oftype(a, x1) : oftype(a, x2)
67+
r = exp(a)
68+
return xmax, r
69+
end
70+
71+
# reduce a number and a partial sum
72+
function _logsumexp_onepass_op(x, (xmax, r)::Tuple)
73+
a = x == xmax ? zero(x - xmax) : -abs(x - xmax)
74+
if x > xmax
75+
_xmax = oftype(a, x)
76+
_r = (r + one(r)) * exp(a)
77+
else
78+
_xmax = oftype(a, xmax)
79+
_r = r + exp(a)
80+
end
81+
return _xmax, _r
82+
end
83+
_logsumexp_onepass_op(xmax_r::Tuple, x) = _logsumexp_onepass_op(x, xmax_r)
84+
85+
# reduce two partial sums
86+
function _logsumexp_onepass_op((xmax1, r1)::Tuple, (xmax2, r2)::Tuple)
87+
a = xmax1 == xmax2 ? zero(xmax1 - xmax2) : -abs(xmax1 - xmax2)
88+
if xmax1 > xmax2
89+
xmax = oftype(a, xmax1)
90+
r = r1 + (r2 + one(r2)) * exp(a)
91+
else
92+
xmax = oftype(a, xmax2)
93+
r = r2 + (r1 + one(r1)) * exp(a)
94+
end
95+
return xmax, r
96+
end

0 commit comments

Comments
 (0)