Skip to content

Commit e1fe1a4

Browse files
authored
Add onepass algorithm for logsumexp (#97)
1 parent 7d53d89 commit e1fe1a4

File tree

2 files changed

+102
-22
lines changed

2 files changed

+102
-22
lines changed

src/basicfuns.jl

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ end
207207
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling non-finite values.
208208
"""
209209
function logaddexp(x::Real, y::Real)
210-
# ensure Δ = 0 if x = y = Inf
210+
# ensure Δ = 0 if x = y = ± Inf
211211
Δ = ifelse(x == y, zero(x - y), abs(x - y))
212212
max(x, y) + log1pexp(-Δ)
213213
end
@@ -224,28 +224,99 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
224224
"""
225225
logsumexp(X)
226226
227-
Compute `log(sum(exp, X))`, evaluated avoiding intermediate overflow/undeflow.
227+
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
228+
underflow.
229+
230+
`X` should be an iterator of real numbers. The result is computed using a single pass over
231+
the data.
232+
233+
# References
234+
235+
[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
236+
"""
237+
logsumexp(X) = _logsumexp_onepass(X)
228238

229-
`X` should be an iterator of real numbers.
230239
"""
231-
function logsumexp(X)
240+
logsumexp(X::AbstractArray{<:Real}; dims=:)
241+
242+
Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
243+
intermediate over- and underflow.
244+
245+
The result is computed using a single pass over the data.
246+
247+
# References
248+
249+
[Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
250+
"""
251+
logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims)
252+
253+
_logsumexp(X::AbstractArray{<:Real}, ::Colon) = _logsumexp_onepass(X)
254+
function _logsumexp(X::AbstractArray{<:Real}, dims)
255+
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
256+
FT = float(eltype(X))
257+
xmax_r = reduce(_logsumexp_onepass_op, X; dims=dims, init=(FT(-Inf), zero(FT)))
258+
return @. first(xmax_r) + log1p(last(xmax_r))
259+
end
260+
261+
function _logsumexp_onepass(X)
262+
# fallback for empty collections
232263
isempty(X) && return log(sum(X))
233-
reduce(logaddexp, X)
264+
return _logsumexp_onepass_result(_logsumexp_onepass_reduce(X, Base.IteratorEltype(X)))
234265
end
235-
function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real}
236-
# Do not use log(zero(T)) directly to avoid issues with ForwardDiff (#82)
237-
u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf))
238-
u isa AbstractArray || isfinite(u) || return float(u)
239-
let u=u # avoid https://github.com/JuliaLang/julia/issues/15276
240-
# TODO: remove the branch when JuliaLang/julia#31020 is merged.
241-
if u isa AbstractArray
242-
u .+ log.(sum(exp.(X .- u); dims=dims))
243-
else
244-
u + log(sum(x -> exp(x-u), X))
245-
end
246-
end
266+
267+
# function barrier for reductions with single element and without initial element
268+
_logsumexp_onepass_result(x) = float(x)
269+
_logsumexp_onepass_result((xmax, r)::Tuple) = xmax + log1p(r)
270+
271+
# iterables with known element type
272+
function _logsumexp_onepass_reduce(X, ::Base.HasEltype)
273+
# do not perform type computations if element type is abstract
274+
T = eltype(X)
275+
isconcretetype(T) || return _logsumexp_onepass_reduce(X, Base.EltypeUnknown())
276+
277+
FT = float(T)
278+
return reduce(_logsumexp_onepass_op, X; init=(FT(-Inf), zero(FT)))
279+
end
280+
281+
# iterables without known element type
282+
_logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_op, X)
283+
284+
## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
285+
286+
# reduce two numbers
287+
function _logsumexp_onepass_op(x1, x2)
288+
a = x1 == x2 ? zero(x1 - x2) : -abs(x1 - x2)
289+
xmax = x1 > x2 ? oftype(a, x1) : oftype(a, x2)
290+
r = exp(a)
291+
return xmax, r
247292
end
248293

294+
# reduce a number and a partial sum
295+
function _logsumexp_onepass_op(x, (xmax, r)::Tuple)
296+
a = x == xmax ? zero(x - xmax) : -abs(x - xmax)
297+
if x > xmax
298+
_xmax = oftype(a, x)
299+
_r = (r + one(r)) * exp(a)
300+
else
301+
_xmax = oftype(a, xmax)
302+
_r = r + exp(a)
303+
end
304+
return _xmax, _r
305+
end
306+
_logsumexp_onepass_op(xmax_r::Tuple, x) = _logsumexp_onepass_op(x, xmax_r)
307+
308+
# reduce two partial sums
309+
function _logsumexp_onepass_op((xmax1, r1)::Tuple, (xmax2, r2)::Tuple)
310+
a = xmax1 == xmax2 ? zero(xmax1 - xmax2) : -abs(xmax1 - xmax2)
311+
if xmax1 > xmax2
312+
xmax = oftype(a, xmax1)
313+
r = r1 + (r2 + one(r2)) * exp(a)
314+
else
315+
xmax = oftype(a, xmax2)
316+
r = r2 + (r1 + one(r1)) * exp(a)
317+
end
318+
return xmax, r
319+
end
249320

250321
"""
251322
softmax!(r::AbstractArray, x::AbstractArray)

test/basicfuns.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,18 @@ end
9696
@test logaddexp(2.0, 3.0) log(exp(2.0) + exp(3.0))
9797
@test logaddexp(10002, 10003) 10000 + logaddexp(2.0, 3.0)
9898

99-
@test logsumexp([1.0, 2.0, 3.0]) 3.40760596444438
100-
@test logsumexp((1.0, 2.0, 3.0)) 3.40760596444438
99+
@test @inferred(logsumexp([1.0])) == 1.0
100+
@test @inferred(logsumexp((x for x in [1.0]))) == 1.0
101+
@test @inferred(logsumexp([1.0, 2.0, 3.0])) 3.40760596444438
102+
@test @inferred(logsumexp((1.0, 2.0, 3.0))) 3.40760596444438
101103
@test logsumexp([1.0, 2.0, 3.0] .+ 1000.) 1003.40760596444438
102104

103-
@test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1) [3.40760596444438 1003.40760596444438]
104-
@test logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2) [3.40760596444438, 1003.40760596444438]
105-
@test logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2]) [1003.4076059644444]
105+
@test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1)) [3.40760596444438 1003.40760596444438]
106+
@test @inferred(logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2)) [3.40760596444438, 1003.40760596444438]
107+
@test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2])) [1003.4076059644444]
108+
109+
# check underflow
110+
@test logsumexp([1e-20, log(1e-20)]) 2e-20
106111

107112
let cases = [([-Inf, -Inf], -Inf), # correct handling of all -Inf
108113
([-Inf, -Inf32], -Inf), # promotion
@@ -137,6 +142,10 @@ end
137142
@test isnan(logsumexp([NaN, 9.0]))
138143
@test isnan(logsumexp([NaN, Inf]))
139144
@test isnan(logsumexp([NaN, -Inf]))
145+
146+
# logsumexp with general iterables (issue #63)
147+
xs = range(-500, stop = 10, length = 1000)
148+
@test @inferred(logsumexp(x for x in xs)) == logsumexp(xs)
140149
end
141150

142151
@testset "softmax" begin

0 commit comments

Comments
 (0)