diff --git a/Project.toml b/Project.toml index 2c843af..17f3aab 100644 --- a/Project.toml +++ b/Project.toml @@ -27,18 +27,5 @@ ChangesOfVariables = "0.1" DocStringExtensions = "0.8, 0.9" InverseFunctions = "0.1" IrrationalConstants = "0.1, 0.2" +LinearAlgebra = "1.10" julia = "1.10" - -[extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" -OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "ChangesOfVariables", "FiniteDifferences", "ForwardDiff", "InverseFunctions", "OffsetArrays", "Random", "Test"] diff --git a/src/logsumexp.jl b/src/logsumexp.jl index 0540cca..0c655f4 100644 --- a/src/logsumexp.jl +++ b/src/logsumexp.jl @@ -114,6 +114,9 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) = _logsumexp_onepass_op(x::Number, xmax::Number, r::Number) = _logsumexp_onepass_op(promote(x, xmax)..., r) function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number} + # The following invariants are maintained through the reduction: + # `xmax` is the maximum of the elements encountered so far, + # ``r = ∑ᵢ exp(xᵢ - xmax)`` over the same elements. _xmax, _r = if isnan(x) || isnan(xmax) # ensure that `NaN` is propagated correctly for complex numbers z = oftype(x, NaN) @@ -121,14 +124,13 @@ function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number} else real_x = real(x) real_xmax = real(xmax) - if real_x > real_xmax + if isinf(real_x) && isinf(real_xmax) && (real_x * real_xmax) > 0 + # handle `x = xmax = ±Inf` correctly, without relying on ForwardDiff.value + xmax, r + exp(zero(x - xmax)) + elseif real_x > real_xmax x, (r + one(r)) * exp(xmax - x) - elseif real_x < real_xmax - xmax, r + exp(x - xmax) else - # handle `x = xmax = ±Inf` correctly - # checking inequalities above instead of equality fixes issue #59 - xmax, r + exp(zero(x - xmax)) + xmax, r + exp(x - xmax) end end return _xmax, _r diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..fdbcaac --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,13 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index b9665e7..27b247f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,3 +16,9 @@ include("basicfuns.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") + +# QA +import JET +JET.report_package("LogExpFunctions") +import Aqua +Aqua.test_all(LogExpFunctions)