Skip to content

Commit 239b2ad

Browse files
authored
Add logsumexp! (#39)
* Add `logsumexp!` * Add some tests * Update docstring * Fix test error * Fix test * More test fixes * Add more tests * Update docstring * Update docstrings
1 parent 4e50c54 commit 239b2ad

File tree

5 files changed

+69
-12
lines changed

5 files changed

+69
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ logmxp1
2525
logaddexp
2626
logsubexp
2727
logsumexp
28+
logsumexp!
2829
softmax!
2930
softmax
3031
```

src/LogExpFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import IrrationalConstants
1010
import LinearAlgebra
1111

1212
export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
13-
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax,
13+
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
1414
softmax!, logcosh
1515

1616
include("basicfuns.jl")

src/logsumexp.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""
22
$(SIGNATURES)
33
4-
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
5-
underflow.
4+
Compute `log(sum(exp, X))`.
65
7-
`X` should be an iterator of real or complex numbers. The result is computed using a single
8-
pass over the data.
6+
`X` should be an iterator of real or complex numbers.
7+
The result is computed in a numerically stable way that avoids intermediate over- and underflow, using a single pass over the data.
8+
9+
See also [`logsumexp!`](@ref).
910
1011
# References
1112
@@ -16,17 +17,38 @@ logsumexp(X) = _logsumexp_onepass(X)
1617
"""
1718
$(SIGNATURES)
1819
19-
Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
20-
intermediate over- and underflow.
20+
Compute `log.(sum(exp.(X); dims=dims))`.
21+
22+
The result is computed in a numerically stable way that avoids intermediate over- and underflow, using a single pass over the data.
2123
22-
The result is computed using a single pass over the data.
24+
See also [`logsumexp!`](@ref).
2325
2426
# References
2527
2628
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
2729
"""
2830
logsumexp(X::AbstractArray{<:Number}; dims=:) = _logsumexp(X, dims)
2931

32+
"""
33+
$(SIGNATURES)
34+
35+
Compute [`logsumexp`](@ref) of `X` over the singleton dimensions of `out`, and write results to `out`.
36+
37+
The result is computed in a numerically stable way that avoids intermediate over- and underflow, using a single pass over the data.
38+
39+
See also [`logsumexp`](@ref).
40+
41+
# References
42+
43+
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
44+
"""
45+
function logsumexp!(out::AbstractArray{<:Number}, X::AbstractArray{<:Number})
46+
FT = eltype(out)
47+
xmax_r = fill!(similar(out, Tuple{FT,FT}), (FT(-Inf), zero(FT)))
48+
Base.reducedim!(_logsumexp_onepass_op, xmax_r, X)
49+
return @. out = first(xmax_r) + log1p(last(xmax_r))
50+
end
51+
3052
_logsumexp(X::AbstractArray{<:Number}, ::Colon) = _logsumexp_onepass(X)
3153
function _logsumexp(X::AbstractArray{<:Number}, dims)
3254
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)

test/basicfuns.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,31 @@ end
195195

196196
_x = [[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]
197197
for x in (_x, complex(_x))
198-
@test @inferred(logsumexp(x; dims=1)) [3.40760596444438 1003.40760596444438]
199-
@test @inferred(logsumexp(x; dims=[1, 2])) [1003.4076059644444]
198+
expected = [3.40760596444438 1003.40760596444438]
199+
@test @inferred(logsumexp(x; dims=1)) expected
200+
out = Array{eltype(x)}(undef, 1, 2)
201+
@test @inferred(logsumexp!(out, x)) expected
202+
@test out expected
203+
200204
y = copy(x')
201-
@test @inferred(logsumexp(y; dims=2)) [3.40760596444438, 1003.40760596444438]
205+
expected = [3.40760596444438, 1003.40760596444438]
206+
@test @inferred(logsumexp(y; dims=2)) expected
207+
out = Array{eltype(y)}(undef, 2)
208+
@test @inferred(logsumexp!(out, y)) expected
209+
@test out expected
210+
211+
expected = [1003.4076059644444]
212+
@test @inferred(logsumexp(x; dims=[1, 2])) expected
213+
out = Array{eltype(x)}(undef, 1)
214+
@test @inferred(logsumexp!(out, x)) expected
215+
@test out expected
202216
end
203217

204218
# check underflow
205219
@test logsumexp([1e-20, log(1e-20)]) 2e-20
206220
@test logsumexp(Complex{Float64}[1e-20, log(1e-20)]) 2e-20
221+
@test logsumexp!([1.0], [1e-20, log(1e-20)]) [2e-20]
222+
@test logsumexp!(Complex{Float64}[1.0], Complex{Float64}[1e-20, log(1e-20)]) [2e-20]
207223

208224
let cases = [([-Inf, -Inf], -Inf), # correct handling of all -Inf
209225
([-Inf, -Inf32], -Inf), # promotion
@@ -216,6 +232,15 @@ end
216232
@test logaddexp(arguments...) result
217233
@test logsumexp(arguments) result
218234
@test logsumexp(complex(arguments)) complex(result)
235+
236+
FT = float(eltype(arguments))
237+
out = [one(FT)]
238+
@test logsumexp!(out, arguments)[1] result
239+
@test out[1] result
240+
241+
out = [one(complex(FT))]
242+
@test logsumexp!(out, complex(arguments))[1] complex(result)
243+
@test out[1] complex(result)
219244
end
220245
end
221246

@@ -250,6 +275,15 @@ end
250275
@test isnan(logsumexp(Complex{Float64}[NaN * im, 9.0]))
251276
@test isnan(logsumexp(Complex{Float64}[NaN * im, Inf]))
252277
@test isnan(logsumexp(Complex{Float64}[NaN * im, -Inf]))
278+
@test isnan(logsumexp!([1.0], [NaN, 9.0])[1])
279+
@test isnan(logsumexp!([1.0], [NaN, Inf])[1])
280+
@test isnan(logsumexp!([1.0], [NaN, -Inf])[1])
281+
@test isnan(logsumexp!(Complex{Float64}[1.0], Complex{Float64}[NaN, 9.0])[1])
282+
@test isnan(logsumexp!(Complex{Float64}[1.0], Complex{Float64}[NaN, Inf])[1])
283+
@test isnan(logsumexp!(Complex{Float64}[1.0], Complex{Float64}[NaN, -Inf])[1])
284+
@test isnan(logsumexp!(Complex{Float64}[1.0], Complex{Float64}[NaN * im, 9.0])[1])
285+
@test isnan(logsumexp!(Complex{Float64}[1.0], Complex{Float64}[NaN * im, Inf])[1])
286+
@test isnan(logsumexp!(Complex{Float64}[1.0], Complex{Float64}[NaN * im, -Inf])[1])
253287

254288
# logsumexp with general iterables (issue #63)
255289
xs = range(-500, stop = 10, length = 1000)

0 commit comments

Comments
 (0)