Skip to content

Commit 59b8c09

Browse files
authored
Fix issues with ForwardDiff 0.10.33 (#60)
1 parent 956de31 commit 59b8c09

File tree

4 files changed

+53
-23
lines changed

4 files changed

+53
-23
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.18"
4+
version = "0.3.19"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -21,10 +21,11 @@ julia = "1"
2121

2222
[extras]
2323
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
24+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2425
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2526
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2627
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2728
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2829

2930
[targets]
30-
test = ["ChainRulesTestUtils", "FiniteDifferences", "OffsetArrays", "Random", "Test"]
31+
test = ["ChainRulesTestUtils", "FiniteDifferences", "ForwardDiff", "OffsetArrays", "Random", "Test"]

src/logsumexp.jl

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,22 @@ _logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_o
8484

8585
# reduce two numbers
8686
function _logsumexp_onepass_op(x1::T, x2::T) where {T<:Number}
87-
xmax, a = if x1 == x2
88-
# handle `x1 = x2 = ±Inf` correctly
89-
x2, zero(x1 - x2)
90-
elseif isnan(x1) || isnan(x2)
87+
xmax, a = if isnan(x1) || isnan(x2)
9188
# ensure that `NaN` is propagated correctly for complex numbers
9289
z = oftype(x1, NaN)
9390
z, exp(z)
94-
elseif real(x1) > real(x2)
95-
x1, x2 - x1
9691
else
97-
x2, x1 - x2
92+
real_x1 = real(x1)
93+
real_x2 = real(x2)
94+
if real_x1 > real_x2
95+
x1, x2 - x1
96+
elseif real_x1 < real_x2
97+
x2, x1 - x2
98+
else
99+
# handle `x1 = x2 = ±Inf` correctly
100+
# checking inequalities above instead of equality fixes issue #59
101+
x2, zero(x1 - x2)
102+
end
98103
end
99104
r = exp(a)
100105
return xmax, r
@@ -109,17 +114,22 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) =
109114
_logsumexp_onepass_op(x::Number, xmax::Number, r::Number) =
110115
_logsumexp_onepass_op(promote(x, xmax)..., r)
111116
function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number}
112-
_xmax, _r = if x == xmax
113-
# handle `x = xmax = ±Inf` correctly
114-
xmax, r + exp(zero(x - xmax))
115-
elseif isnan(x) || isnan(xmax)
117+
_xmax, _r = if isnan(x) || isnan(xmax)
116118
# ensure that `NaN` is propagated correctly for complex numbers
117119
z = oftype(x, NaN)
118120
z, r + exp(z)
119-
elseif real(x) > real(xmax)
120-
x, (r + one(r)) * exp(xmax - x)
121121
else
122-
xmax, r + exp(x - xmax)
122+
real_x = real(x)
123+
real_xmax = real(xmax)
124+
if real_x > real_xmax
125+
x, (r + one(r)) * exp(xmax - x)
126+
elseif real_x < real_xmax
127+
xmax, r + exp(x - xmax)
128+
else
129+
# handle `x = xmax = ±Inf` correctly
130+
# checking inequalities above instead of equality fixes issue #59
131+
xmax, r + exp(zero(x - xmax))
132+
end
123133
end
124134
return _xmax, _r
125135
end
@@ -134,17 +144,22 @@ function _logsumexp_onepass_op(xmax1::Number, xmax2::Number, r1::Number, r2::Num
134144
return _logsumexp_onepass_op(promote(xmax1, xmax2)..., promote(r1, r2)...)
135145
end
136146
function _logsumexp_onepass_op(xmax1::T, xmax2::T, r1::R, r2::R) where {T<:Number,R<:Number}
137-
xmax, r = if xmax1 == xmax2
138-
# handle `xmax1 = xmax2 = ±Inf` correctly
139-
xmax2, r2 + (r1 + one(r1)) * exp(zero(xmax1 - xmax2))
140-
elseif isnan(xmax1) || isnan(xmax2)
147+
xmax, r = if isnan(xmax1) || isnan(xmax2)
141148
# ensure that `NaN` is propagated correctly for complex numbers
142149
z = oftype(xmax1, NaN)
143150
z, r1 + exp(z)
144-
elseif real(xmax1) > real(xmax2)
145-
xmax1, r1 + (r2 + one(r2)) * exp(xmax2 - xmax1)
146151
else
147-
xmax2, r2 + (r1 + one(r1)) * exp(xmax1 - xmax2)
152+
real_xmax1 = real(xmax1)
153+
real_xmax2 = real(xmax2)
154+
if real_xmax1 > real_xmax2
155+
xmax1, r1 + (r2 + one(r2)) * exp(xmax2 - xmax1)
156+
elseif real_xmax1 < real_xmax2
157+
xmax2, r2 + (r1 + one(r1)) * exp(xmax1 - xmax2)
158+
else
159+
# handle `xmax1 = xmax2 = ±Inf` correctly
160+
# checking inequalities above instead of equality fixes issue #59
161+
xmax2, r2 + (r1 + one(r1)) * exp(zero(xmax1 - xmax2))
162+
end
148163
end
149164
return xmax, r
150165
end

test/basicfuns.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,19 @@ end
342342
expected = logsumexp(xs; dims=2)
343343
@test logsumexp!(out, xs) expected
344344
@test out expected
345+
346+
@testset "ForwardDiff" begin
347+
# vector with finite numbers
348+
x = randn(10)
349+
∇x = unthunk(rrule(logsumexp, x)[2](1)[2])
350+
@test ForwardDiff.gradient(logsumexp, x) ∇x
351+
352+
# issue #59
353+
x = vcat(-Inf, randn(9))
354+
∇x = unthunk(rrule(logsumexp, x)[2](1)[2])
355+
@assert all(isfinite, ∇x)
356+
@test ForwardDiff.gradient(logsumexp, x) ∇x
357+
end
345358
end
346359

347360
@testset "softmax" begin

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ChainRulesTestUtils
33
using ChainRulesCore
44
using ChangesOfVariables
55
using FiniteDifferences
6+
using ForwardDiff
67
using InverseFunctions
78
using OffsetArrays
89

0 commit comments

Comments
 (0)