Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,5 @@ ChangesOfVariables = "0.1"
DocStringExtensions = "0.8, 0.9"
InverseFunctions = "0.1"
IrrationalConstants = "0.1, 0.2"
LinearAlgebra = "1.10"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesCore", "ChainRulesTestUtils", "ChangesOfVariables", "FiniteDifferences", "ForwardDiff", "InverseFunctions", "OffsetArrays", "Random", "Test"]
14 changes: 8 additions & 6 deletions src/logsumexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,23 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) =
_logsumexp_onepass_op(x::Number, xmax::Number, r::Number) =
_logsumexp_onepass_op(promote(x, xmax)..., r)
function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number}
# The following invariants are maintained through the reduction:
# `xmax` is the maximum of the elements encountered so far,
# ``r = ∑ᵢ exp(xᵢ - xmax)`` over the same elements.
_xmax, _r = if isnan(x) || isnan(xmax)
# ensure that `NaN` is propagated correctly for complex numbers
z = oftype(x, NaN)
z, r + exp(z)
else
real_x = real(x)
real_xmax = real(xmax)
if real_x > real_xmax
if isinf(real_x) && isinf(real_xmax) && (real_x * real_xmax) > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change should be benchmarked. IIRC the performance of logsumexp was very sensitive to even minor changes of the implementation. To avoid the multiplication one might want to use

Suggested change
if isinf(real_x) && isinf(real_xmax) && (real_x * real_xmax) > 0
if isinf(real_x) && isinf(real_xmax) && sign(real_x) == sign(real_xmax)

# handle `x = xmax = ±Inf` correctly, without relying on ForwardDiff.value
xmax, r + exp(zero(x - xmax))
elseif real_x > real_xmax
x, (r + one(r)) * exp(xmax - x)
elseif real_x < real_xmax
xmax, r + exp(x - xmax)
else
# handle `x = xmax = ±Inf` correctly
# checking inequalities above instead of equality fixes issue #59
xmax, r + exp(zero(x - xmax))
xmax, r + exp(x - xmax)
end
end
return _xmax, _r
Expand Down
13 changes: 13 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Comment on lines +1 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO test dependencies should also be specified with compat entries and this file should be updated with CompatHelper.

6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,9 @@ include("basicfuns.jl")
include("chainrules.jl")
include("inverse.jl")
include("with_logabsdet_jacobian.jl")

# QA
import JET
JET.report_package("LogExpFunctions")
import Aqua
Aqua.test_all(LogExpFunctions)