Skip to content

Commit 35fadbf

Browse files
authored
Generalize logsumexp to complex numbers (#19)
1 parent 9e5de81 commit 35fadbf

File tree

3 files changed

+94
-32
lines changed

3 files changed

+94
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/logsumexp.jl

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ $(SIGNATURES)
44
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
55
underflow.
66
7-
`X` should be an iterator of real numbers. The result is computed using a single pass over
8-
the data.
7+
`X` should be an iterator of real or complex numbers. The result is computed using a single
8+
pass over the data.
99
1010
# References
1111
@@ -25,10 +25,10 @@ The result is computed using a single pass over the data.
2525
2626
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
2727
"""
28-
logsumexp(X::AbstractArray{<:Real}; dims=:) = _logsumexp(X, dims)
28+
logsumexp(X::AbstractArray{<:Number}; dims=:) = _logsumexp(X, dims)
2929

30-
_logsumexp(X::AbstractArray{<:Real}, ::Colon) = _logsumexp_onepass(X)
31-
function _logsumexp(X::AbstractArray{<:Real}, dims)
30+
_logsumexp(X::AbstractArray{<:Number}, ::Colon) = _logsumexp_onepass(X)
31+
function _logsumexp(X::AbstractArray{<:Number}, dims)
3232
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
3333
FT = float(eltype(X))
3434
xmax_r = reduce(_logsumexp_onepass_op, X; dims=dims, init=(FT(-Inf), zero(FT)))
@@ -61,36 +61,68 @@ _logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_o
6161
## Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
6262

6363
# reduce two numbers
64-
function _logsumexp_onepass_op(x1, x2)
65-
a = x1 == x2 ? zero(x1 - x2) : -abs(x1 - x2)
66-
xmax = x1 > x2 ? oftype(a, x1) : oftype(a, x2)
64+
function _logsumexp_onepass_op(x1::T, x2::T) where {T<:Number}
65+
xmax, a = if x1 == x2
66+
# handle `x1 = x2 = ±Inf` correctly
67+
x2, zero(x1 - x2)
68+
elseif isnan(x1) || isnan(x2)
69+
# ensure that `NaN` is propagated correctly for complex numbers
70+
z = oftype(x1, NaN)
71+
z, exp(z)
72+
elseif real(x1) > real(x2)
73+
x1, x2 - x1
74+
else
75+
x2, x1 - x2
76+
end
6777
r = exp(a)
6878
return xmax, r
6979
end
80+
_logsumexp_onepass_op(x1::Number, x2::Number) = _logsumexp_onepass_op(promote(x1, x2)...)
7081

7182
# reduce a number and a partial sum
72-
function _logsumexp_onepass_op(x, (xmax, r)::Tuple)
73-
a = x == xmax ? zero(x - xmax) : -abs(x - xmax)
74-
if x > xmax
75-
_xmax = oftype(a, x)
76-
_r = (r + one(r)) * exp(a)
83+
_logsumexp_onepass_op(x::Number, (xmax, r)::Tuple{<:Number,<:Number}) =
84+
_logsumexp_onepass_op(x, xmax, r)
85+
_logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) =
86+
_logsumexp_onepass_op(x, xmax, r)
87+
_logsumexp_onepass_op(x::Number, xmax::Number, r::Number) =
88+
_logsumexp_onepass_op(promote(x, xmax)..., r)
89+
function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number}
90+
_xmax, _r = if x == xmax
91+
# handle `x = xmax = ±Inf` correctly
92+
xmax, r + exp(zero(x - xmax))
93+
elseif isnan(x) || isnan(xmax)
94+
# ensure that `NaN` is propagated correctly for complex numbers
95+
z = oftype(x, NaN)
96+
z, r + exp(z)
97+
elseif real(x) > real(xmax)
98+
x, (r + one(r)) * exp(xmax - x)
7799
else
78-
_xmax = oftype(a, xmax)
79-
_r = r + exp(a)
100+
xmax, r + exp(x - xmax)
80101
end
81102
return _xmax, _r
82103
end
83-
_logsumexp_onepass_op(xmax_r::Tuple, x) = _logsumexp_onepass_op(x, xmax_r)
84104

85105
# reduce two partial sums
86-
function _logsumexp_onepass_op((xmax1, r1)::Tuple, (xmax2, r2)::Tuple)
87-
a = xmax1 == xmax2 ? zero(xmax1 - xmax2) : -abs(xmax1 - xmax2)
88-
if xmax1 > xmax2
89-
xmax = oftype(a, xmax1)
90-
r = r1 + (r2 + one(r2)) * exp(a)
106+
function _logsumexp_onepass_op(
107+
(xmax1, r1)::Tuple{<:Number,<:Number}, (xmax2, r2)::Tuple{<:Number,<:Number}
108+
)
109+
return _logsumexp_onepass_op(xmax1, xmax2, r1, r2)
110+
end
111+
function _logsumexp_onepass_op(xmax1::Number, xmax2::Number, r1::Number, r2::Number)
112+
return _logsumexp_onepass_op(promote(xmax1, xmax2)..., promote(r1, r2)...)
113+
end
114+
function _logsumexp_onepass_op(xmax1::T, xmax2::T, r1::R, r2::R) where {T<:Number,R<:Number}
115+
xmax, r = if xmax1 == xmax2
116+
# handle `xmax1 = xmax2 = ±Inf` correctly
117+
xmax2, r2 + (r1 + one(r1)) * exp(zero(xmax1 - xmax2))
118+
elseif isnan(xmax1) || isnan(xmax2)
119+
# ensure that `NaN` is propagated correctly for complex numbers
120+
z = oftype(xmax1, NaN)
121+
z, r1 + exp(z)
122+
elseif real(xmax1) > real(xmax2)
123+
xmax1, r1 + (r2 + one(r2)) * exp(xmax2 - xmax1)
91124
else
92-
xmax = oftype(a, xmax2)
93-
r = r2 + (r1 + one(r1)) * exp(a)
125+
xmax2, r2 + (r1 + one(r1)) * exp(xmax1 - xmax2)
94126
end
95127
return xmax, r
96128
end

test/basicfuns.jl

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,31 @@ end
9494
@test logaddexp(2.0, 3.0) log(exp(2.0) + exp(3.0))
9595
@test logaddexp(10002, 10003) 10000 + logaddexp(2.0, 3.0)
9696

97-
@test @inferred(logsumexp([1.0])) == 1.0
98-
@test @inferred(logsumexp((x for x in [1.0]))) == 1.0
99-
@test @inferred(logsumexp([1.0, 2.0, 3.0])) 3.40760596444438
100-
@test @inferred(logsumexp((1.0, 2.0, 3.0))) 3.40760596444438
101-
@test logsumexp([1.0, 2.0, 3.0] .+ 1000.) 1003.40760596444438
97+
for x in ([1.0], Complex{Float64}[1.0])
98+
@test @inferred(logsumexp(x)) == 1.0
99+
@test @inferred(logsumexp((xi for xi in x))) == 1.0
100+
end
102101

103-
@test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=1)) [3.40760596444438 1003.40760596444438]
104-
@test @inferred(logsumexp([[1.0 2.0 3.0]; [1.0 2.0 3.0] .+ 1000.]; dims=2)) [3.40760596444438, 1003.40760596444438]
105-
@test @inferred(logsumexp([[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]; dims=[1,2])) [1003.4076059644444]
102+
for x in ([1.0, 2.0, 3.0], Complex{Float64}[1.0, 2.0, 3.0])
103+
@test @inferred(logsumexp(x)) 3.40760596444438
104+
@test logsumexp(x .+ 1000) 1003.40760596444438
105+
end
106+
107+
for x in ((1.0, 2.0, 3.0), map(complex, (1.0, 2.0, 3.0)))
108+
@test @inferred(logsumexp(x)) 3.40760596444438
109+
end
110+
111+
_x = [[1.0, 2.0, 3.0] [1.0, 2.0, 3.0] .+ 1000.]
112+
for x in (_x, complex(_x))
113+
@test @inferred(logsumexp(x; dims=1)) [3.40760596444438 1003.40760596444438]
114+
@test @inferred(logsumexp(x; dims=[1, 2])) [1003.4076059644444]
115+
y = copy(x')
116+
@test @inferred(logsumexp(y; dims=2)) [3.40760596444438, 1003.40760596444438]
117+
end
106118

107119
# check underflow
108120
@test logsumexp([1e-20, log(1e-20)]) 2e-20
121+
@test logsumexp(Complex{Float64}[1e-20, log(1e-20)]) 2e-20
109122

110123
let cases = [([-Inf, -Inf], -Inf), # correct handling of all -Inf
111124
([-Inf, -Inf32], -Inf), # promotion
@@ -117,6 +130,7 @@ end
117130
for (arguments, result) in cases
118131
@test logaddexp(arguments...) result
119132
@test logsumexp(arguments) result
133+
@test logsumexp(complex(arguments)) complex(result)
120134
end
121135
end
122136

@@ -140,10 +154,26 @@ end
140154
@test isnan(logsumexp([NaN, 9.0]))
141155
@test isnan(logsumexp([NaN, Inf]))
142156
@test isnan(logsumexp([NaN, -Inf]))
157+
@test isnan(logsumexp(Complex{Float64}[NaN, 9.0]))
158+
@test isnan(logsumexp(Complex{Float64}[NaN, Inf]))
159+
@test isnan(logsumexp(Complex{Float64}[NaN, -Inf]))
160+
@test isnan(logsumexp(Complex{Float64}[NaN * im, 9.0]))
161+
@test isnan(logsumexp(Complex{Float64}[NaN * im, Inf]))
162+
@test isnan(logsumexp(Complex{Float64}[NaN * im, -Inf]))
143163

144164
# logsumexp with general iterables (issue #63)
145165
xs = range(-500, stop = 10, length = 1000)
146166
@test @inferred(logsumexp(x for x in xs)) == logsumexp(xs)
167+
xs = range(-500 + 0.5im, stop = 10 + 30im, length = 1000)
168+
@test @inferred(logsumexp(x for x in xs)) == logsumexp(xs)
169+
170+
# complex numbers
171+
xs = randn(Complex{Float64}, 10, 5)
172+
@test @inferred(logsumexp(xs)) log(sum(exp.(xs)))
173+
@test @inferred(logsumexp(xs; dims=1)) log.(sum(exp.(xs); dims=1))
174+
@test @inferred(logsumexp(xs; dims=2)) log.(sum(exp.(xs); dims=2))
175+
@test @inferred(logsumexp(xs; dims=[1, 2])) log(sum(exp.(xs); dims=[1, 2]))
176+
@test @inferred(logsumexp(x for x in xs)) == logsumexp(xs)
147177
end
148178

149179
@testset "softmax" begin

0 commit comments

Comments
 (0)