Skip to content

Commit 1e96ec8

Browse files
authored
Fix logsubexp(-Inf, -Inf) (#21)
1 parent e5cc905 commit 1e96ec8

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.2.4"
4+
version = "0.2.5"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/basicfuns.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,11 @@ $(SIGNATURES)
221221
222222
Return `log(abs(exp(x) - exp(y)))`, preserving numerical accuracy.
223223
"""
224-
logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
224+
function logsubexp(x::Real, y::Real)
225+
# ensure that `Δ = 0` if `x = y = - Inf` (but not for `x = y = +Inf`!)
226+
Δ = x == y && (isfinite(x) || x < 0) ? zero(x - y) : abs(x - y)
227+
return max(x, y) + log1mexp(-Δ)
228+
end
225229

226230
"""
227231
$(SIGNATURES)

test/basicfuns.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,21 +135,26 @@ end
135135
end
136136

137137
@test isnan(logsubexp(Inf, Inf))
138-
@test isnan(logsubexp(-Inf, -Inf))
138+
@test logsubexp(-Inf, -Inf) -Inf
139+
@test logsubexp(Inf, -Inf) Inf
140+
@test logsubexp(-Inf, Inf) Inf
139141
@test logsubexp(Inf, 9.0) Inf
142+
@test logsubexp(9.0, Inf) Inf
140143
@test logsubexp(-Inf, 9.0) 9.0
144+
@test logsubexp(9.0, -Inf) 9.0
141145
@test logsubexp(1f2, 1f2) -Inf32
142146
@test logsubexp(0, 0) -Inf
143147
@test logsubexp(3, 2) 2.541324854612918108978
144148

145149
# NaN propagation
146-
@test isnan(logaddexp(NaN, 9.0))
147-
@test isnan(logaddexp(NaN, Inf))
148-
@test isnan(logaddexp(NaN, -Inf))
149-
150-
@test isnan(logsubexp(NaN, 9.0))
151-
@test isnan(logsubexp(NaN, Inf))
152-
@test isnan(logsubexp(NaN, -Inf))
150+
for f in (logaddexp, logsubexp)
151+
@test isnan(f(NaN, 9.0))
152+
@test isnan(f(NaN, Inf))
153+
@test isnan(f(NaN, -Inf))
154+
@test isnan(f(9.0, NaN))
155+
@test isnan(f(Inf, NaN))
156+
@test isnan(f(-Inf, NaN))
157+
end
153158

154159
@test isnan(logsumexp([NaN, 9.0]))
155160
@test isnan(logsumexp([NaN, Inf]))
@@ -205,7 +210,7 @@ end
205210
s = softmax(x)
206211
@test s r
207212
@test eltype(s) === Float64
208-
213+
209214
# non-standard indices: #12
210215
x = OffsetArray(1:3, -2:0)
211216
s = softmax(x)

0 commit comments

Comments
 (0)