diff --git a/Project.toml b/Project.toml index 5b0200e67..6d38d4e40 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.162.1" +version = "6.162.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index 661152000..adfb044de 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -501,19 +501,19 @@ unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.unitfulvalue(x)) sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x)) -function totallength(x::ForwardDiff.Dual) - totallength(ForwardDiff.value(x)) + sum(totallength, ForwardDiff.partials(x)) +function DiffEqBase.totallength(x::ForwardDiff.Dual) + return DiffEqBase.totallength(ForwardDiff.value(x)) + sum(DiffEqBase.totallength, ForwardDiff.partials(x)) end @inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u)) @inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}}, t::Any) where {Tag, T} - sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u)) + sqrt(DiffEqBase.__sum(sse, u; init = sse(zero(T))) / DiffEqBase.totallength(u)) end @inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::ForwardDiff.Dual) = sqrt(sse(u)) @inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}}, ::ForwardDiff.Dual) where {Tag, T} - sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u)) + sqrt(DiffEqBase.__sum(sse, u; init = sse(zero(T))) / DiffEqBase.totallength(u)) end if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual}) @@ -528,13 +528,6 @@ end # bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi()) -# Static Arrays don't support the `init` keyword argument for `sum` -@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...) -@inline function __sum( - f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F} - return mapreduce(f, +, a...; init, kwargs...) -end - # Differentiation of internal solver function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...) diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index ed2be1ba9..872d519a1 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -4,7 +4,6 @@ using DiffEqBase import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface -import DiffEqBase.ForwardDiff function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where { diff --git a/src/utils.jl b/src/utils.jl index e4a4b9301..8d941c127 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,13 @@ unitfulvalue(x) = x isdistribution(u0) = false sse(x::Number) = abs2(x) +# Static Arrays don't support the `init` keyword argument for `sum` +@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...) +@inline function __sum( + f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F} + return mapreduce(f, +, a...; init, kwargs...) +end + totallength(x::Number) = 1 totallength(x::AbstractArray) = __sum(totallength, x; init = 0) diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 292c48d9e..3685164ea 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -379,3 +379,13 @@ end ReverseDiff.TrackedReal{<:ForwardDiff.Dual} @test DiffEqBase.promote_u0(NaN, [NaN], 0.0) isa Float64 @test DiffEqBase.promote_u0([1.0], [NaN], 0.0) isa Vector{Float64} + +# totallength +val = rand(10) +par = rand(10) +u = Dual.(val, par) +@test DiffEqBase.totallength(val[1]) == 1 +@test DiffEqBase.totallength(val) == length(val) +@test DiffEqBase.totallength(par) == length(par) +@test DiffEqBase.totallength(u[1]) == DiffEqBase.totallength(val[1]) + DiffEqBase.totallength(par[1]) +@test DiffEqBase.totallength(u) == sum(DiffEqBase.totallength, u)