@@ -4,8 +4,8 @@ $(SIGNATURES)
4
4
Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
5
5
underflow.
6
6
7
- `X` should be an iterator of real numbers. The result is computed using a single pass over
8
- the data.
7
+ `X` should be an iterator of real or complex numbers. The result is computed using a single
8
+ pass over the data.
9
9
10
10
# References
11
11
@@ -25,10 +25,10 @@ The result is computed using a single pass over the data.
25
25
26
26
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
27
27
"""
28
- logsumexp (X:: AbstractArray{<:Real } ; dims= :) = _logsumexp (X, dims)
28
+ logsumexp (X:: AbstractArray{<:Number } ; dims= :) = _logsumexp (X, dims)
29
29
30
- _logsumexp (X:: AbstractArray{<:Real } , :: Colon ) = _logsumexp_onepass (X)
31
- function _logsumexp (X:: AbstractArray{<:Real } , dims)
30
+ _logsumexp (X:: AbstractArray{<:Number } , :: Colon ) = _logsumexp_onepass (X)
31
+ function _logsumexp (X:: AbstractArray{<:Number } , dims)
32
32
# Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
33
33
FT = float (eltype (X))
34
34
xmax_r = reduce (_logsumexp_onepass_op, X; dims= dims, init= (FT (- Inf ), zero (FT)))
@@ -61,36 +61,68 @@ _logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_o
61
61
# # Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
62
62
63
63
# reduce two numbers
64
- function _logsumexp_onepass_op (x1, x2)
65
- a = x1 == x2 ? zero (x1 - x2) : - abs (x1 - x2)
66
- xmax = x1 > x2 ? oftype (a, x1) : oftype (a, x2)
64
+ function _logsumexp_onepass_op (x1:: T , x2:: T ) where {T<: Number }
65
+ xmax, a = if x1 == x2
66
+ # handle `x1 = x2 = ±Inf` correctly
67
+ x2, zero (x1 - x2)
68
+ elseif isnan (x1) || isnan (x2)
69
+ # ensure that `NaN` is propagated correctly for complex numbers
70
+ z = oftype (x1, NaN )
71
+ z, exp (z)
72
+ elseif real (x1) > real (x2)
73
+ x1, x2 - x1
74
+ else
75
+ x2, x1 - x2
76
+ end
67
77
r = exp (a)
68
78
return xmax, r
69
79
end
80
+ _logsumexp_onepass_op (x1:: Number , x2:: Number ) = _logsumexp_onepass_op (promote (x1, x2)... )
70
81
71
82
# reduce a number and a partial sum
72
- function _logsumexp_onepass_op (x, (xmax, r):: Tuple )
73
- a = x == xmax ? zero (x - xmax) : - abs (x - xmax)
74
- if x > xmax
75
- _xmax = oftype (a, x)
76
- _r = (r + one (r)) * exp (a)
83
+ _logsumexp_onepass_op (x:: Number , (xmax, r):: Tuple{<:Number,<:Number} ) =
84
+ _logsumexp_onepass_op (x, xmax, r)
85
+ _logsumexp_onepass_op ((xmax, r):: Tuple{<:Number,<:Number} , x:: Number ) =
86
+ _logsumexp_onepass_op (x, xmax, r)
87
+ _logsumexp_onepass_op (x:: Number , xmax:: Number , r:: Number ) =
88
+ _logsumexp_onepass_op (promote (x, xmax)... , r)
89
+ function _logsumexp_onepass_op (x:: T , xmax:: T , r:: Number ) where {T<: Number }
90
+ _xmax, _r = if x == xmax
91
+ # handle `x = xmax = ±Inf` correctly
92
+ xmax, r + exp (zero (x - xmax))
93
+ elseif isnan (x) || isnan (xmax)
94
+ # ensure that `NaN` is propagated correctly for complex numbers
95
+ z = oftype (x, NaN )
96
+ z, r + exp (z)
97
+ elseif real (x) > real (xmax)
98
+ x, (r + one (r)) * exp (xmax - x)
77
99
else
78
- _xmax = oftype (a, xmax)
79
- _r = r + exp (a)
100
+ xmax, r + exp (x - xmax)
80
101
end
81
102
return _xmax, _r
82
103
end
83
- _logsumexp_onepass_op (xmax_r:: Tuple , x) = _logsumexp_onepass_op (x, xmax_r)
84
104
85
105
# reduce two partial sums
86
- function _logsumexp_onepass_op ((xmax1, r1):: Tuple , (xmax2, r2):: Tuple )
87
- a = xmax1 == xmax2 ? zero (xmax1 - xmax2) : - abs (xmax1 - xmax2)
88
- if xmax1 > xmax2
89
- xmax = oftype (a, xmax1)
90
- r = r1 + (r2 + one (r2)) * exp (a)
106
+ function _logsumexp_onepass_op (
107
+ (xmax1, r1):: Tuple{<:Number,<:Number} , (xmax2, r2):: Tuple{<:Number,<:Number}
108
+ )
109
+ return _logsumexp_onepass_op (xmax1, xmax2, r1, r2)
110
+ end
111
+ function _logsumexp_onepass_op (xmax1:: Number , xmax2:: Number , r1:: Number , r2:: Number )
112
+ return _logsumexp_onepass_op (promote (xmax1, xmax2)... , promote (r1, r2)... )
113
+ end
114
+ function _logsumexp_onepass_op (xmax1:: T , xmax2:: T , r1:: R , r2:: R ) where {T<: Number ,R<: Number }
115
+ xmax, r = if xmax1 == xmax2
116
+ # handle `xmax1 = xmax2 = ±Inf` correctly
117
+ xmax2, r2 + (r1 + one (r1)) * exp (zero (xmax1 - xmax2))
118
+ elseif isnan (xmax1) || isnan (xmax2)
119
+ # ensure that `NaN` is propagated correctly for complex numbers
120
+ z = oftype (xmax1, NaN )
121
+ z, r1 + exp (z)
122
+ elseif real (xmax1) > real (xmax2)
123
+ xmax1, r1 + (r2 + one (r2)) * exp (xmax2 - xmax1)
91
124
else
92
- xmax = oftype (a, xmax2)
93
- r = r2 + (r1 + one (r1)) * exp (a)
125
+ xmax2, r2 + (r1 + one (r1)) * exp (xmax1 - xmax2)
94
126
end
95
127
return xmax, r
96
128
end
0 commit comments