Skip to content

Commit 1961420

Browse files
committed
rearrange branches and test early for +-Inf case
This should work for ForwardDiff, too.
1 parent 6422129 commit 1961420

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/logsumexp.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,23 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) =
114114
_logsumexp_onepass_op(x::Number, xmax::Number, r::Number) =
115115
_logsumexp_onepass_op(promote(x, xmax)..., r)
116116
function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number}
117+
# The following invariants are maintained through the reduction:
118+
# `xmax` is the maximum of the elements encountered so far,
119+
# ``r = ∑ᵢ exp(xᵢ - xmax)`` over the same elements.
117120
_xmax, _r = if isnan(x) || isnan(xmax)
118121
# ensure that `NaN` is propagated correctly for complex numbers
119122
z = oftype(x, NaN)
120123
z, r + exp(z)
121124
else
122125
real_x = real(x)
123126
real_xmax = real(xmax)
124-
if real_x > real_xmax
127+
if isinf(real_x) && isinf(real_xmax) && (real_x * real_xmax) > 0
128+
# handle `x = xmax = ±Inf` correctly, without relying on ForwardDiff.value
129+
xmax, r + exp(zero(x - xmax))
130+
elseif real_x > real_xmax
125131
x, (r + one(r)) * exp(xmax - x)
126-
elseif real_x < real_xmax
127-
xmax, r + exp(x - xmax)
128132
else
129-
# handle `x = xmax = ±Inf` correctly
130-
# checking inequalities above instead of equality fixes issue #59
131-
xmax, r + exp(zero(x - xmax))
133+
xmax, r + exp(x - xmax)
132134
end
133135
end
134136
return _xmax, _r

0 commit comments

Comments
 (0)