@@ -114,21 +114,23 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) =
114
114
_logsumexp_onepass_op (x:: Number , xmax:: Number , r:: Number ) =
115
115
_logsumexp_onepass_op (promote (x, xmax)... , r)
116
116
function _logsumexp_onepass_op (x:: T , xmax:: T , r:: Number ) where {T<: Number }
117
+ # The following invariants are maintained through the reduction:
118
+ # `xmax` is the maximum of the elements encountered so far,
119
+ # ``r = ∑ᵢ exp(xᵢ - xmax)`` over the same elements.
117
120
_xmax, _r = if isnan (x) || isnan (xmax)
118
121
# ensure that `NaN` is propagated correctly for complex numbers
119
122
z = oftype (x, NaN )
120
123
z, r + exp (z)
121
124
else
122
125
real_x = real (x)
123
126
real_xmax = real (xmax)
124
- if real_x > real_xmax
127
+ if isinf (real_x) && isinf (real_xmax) && (real_x * real_xmax) > 0
128
+ # handle `x = xmax = ±Inf` correctly, without relying on ForwardDiff.value
129
+ xmax, r + exp (zero (x - xmax))
130
+ elseif real_x > real_xmax
125
131
x, (r + one (r)) * exp (xmax - x)
126
- elseif real_x < real_xmax
127
- xmax, r + exp (x - xmax)
128
132
else
129
- # handle `x = xmax = ±Inf` correctly
130
- # checking inequalities above instead of equality fixes issue #59
131
- xmax, r + exp (zero (x - xmax))
133
+ xmax, r + exp (x - xmax)
132
134
end
133
135
end
134
136
return _xmax, _r
0 commit comments