diff --git a/src/common_defaults.jl b/src/common_defaults.jl index 0c43b50a2..f6d641f5e 100644 --- a/src/common_defaults.jl +++ b/src/common_defaults.jl @@ -30,3 +30,5 @@ Base.mapreduce_empty(::typeof(UNITLESS_ABS2), op, T) = abs2(Base.reduce_empty(op @inline ODE_DEFAULT_UNSTABLE_CHECK(dt,u,p,t) = false @inline ODE_DEFAULT_UNSTABLE_CHECK(dt,u::Union{Number,AbstractArray},p,t) = NAN_CHECK(u) + +@inline UNPERTURBED_NORM(u,t) = DiffEqBase.ODE_DEFAULT_NORM(u,t) diff --git a/src/init.jl b/src/init.jl index e738fd796..57de51e04 100644 --- a/src/init.jl +++ b/src/init.jl @@ -91,6 +91,20 @@ function __init__() get_tmp(dc::DiffCache, u::AbstractArray) = dc.du # bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi()) + + # Support adaptive with non-dual time + @inline UNPERTURBED_NORM(u::AbstractArray{<:ForwardDiff.Dual},::Any) = sqrt(sum(DiffEqBase.UNITLESS_ABS2∘DiffEqBase.value,u) / length(u)) + @inline UNPERTURBED_NORM(u::ForwardDiff.Dual,::Any) = abs(DiffEqBase.value(u)) + + # When time is dual, it shouldn't drop the duals for adaptivity + @inline UNPERTURBED_NORM(u::AbstractArray{<:ForwardDiff.Dual},::ForwardDiff.Dual) = sqrt(sum(DiffEqBase.UNITLESS_ABS2,u) / length(u)) + @inline UNPERTURBED_NORM(u::ForwardDiff.Dual,::ForwardDiff.Dual) = abs(u) + + @inline UNPERTURBED_NORM(u::AbstractArray{<:ForwardDiff.Dual{<:Any,<:ForwardDiff.Dual}},::ForwardDiff.Dual) = sqrt(sum(DiffEqBase.UNITLESS_ABS2∘DiffEqBase.value,u) / length(u)) + @inline UNPERTURBED_NORM(u::ForwardDiff.Dual{<:Any,ForwardDiff.Dual},::ForwardDiff.Dual) = abs(DiffEqBase.value(u)) + + @inline UNPERTURBED_NORM(u::AbstractArray{<:ForwardDiff.Dual{<:Any,<:ForwardDiff.Dual}},::ForwardDiff.Dual{<:Any,ForwardDiff.Dual}) = sqrt(sum(DiffEqBase.UNITLESS_ABS2,u) / length(u)) + @inline UNPERTURBED_NORM(u::ForwardDiff.Dual{<:Any,ForwardDiff.Dual},::ForwardDiff.Dual{<:Any,ForwardDiff.Dual}) = abs(u) end @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin