Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
ArrayInterface = "7"
Expand All @@ -37,4 +38,5 @@ Reexport = "0.2, 1.0"
SciMLBase = "2.59.2"
SimpleNonlinearSolve = "0.1, 1, 2"
SimpleUnPack = "1"
SymbolicIndexingInterface = "0.3.36"
julia = "1.9"
1 change: 1 addition & 0 deletions src/DelayDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using SimpleUnPack

import ArrayInterface
import SimpleNonlinearSolve
import SymbolicIndexingInterface as SII

using DiffEqBase: AbstractDDEAlgorithm, AbstractDDEIntegrator, AbstractODEIntegrator,
DEIntegrator, AbstractDDEProblem
Expand Down
9 changes: 6 additions & 3 deletions src/functionwrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ macro wrap_h(signature)
end |> esc
end

struct ODEFunctionWrapper{iip, F, H, TMM, Ta, Tt, TJ, JP, SP, TW, TWt, TPJ, S, TCV} <:
struct ODEFunctionWrapper{iip, F, H, TMM, Ta, Tt, TJ, JP, SP, TW, TWt, TPJ, S, TCV, ID} <:
DiffEqBase.AbstractODEFunction{iip}
f::F
h::H
Expand All @@ -39,6 +39,7 @@ struct ODEFunctionWrapper{iip, F, H, TMM, Ta, Tt, TJ, JP, SP, TW, TWt, TPJ, S, T
paramjac::TPJ
sys::S
colorvec::TCV
initialization_data::ID
end

function ODEFunctionWrapper(f::DiffEqBase.AbstractDDEFunction, h)
Expand All @@ -51,7 +52,8 @@ function ODEFunctionWrapper(f::DiffEqBase.AbstractDDEFunction, h)
typeof(f.analytic), typeof(f.tgrad), typeof(jac),
typeof(f.jac_prototype), typeof(f.sparsity),
typeof(Wfact), typeof(Wfact_t),
typeof(f.paramjac), typeof(f.sys), typeof(f.colorvec)}(f.f, h,
typeof(f.paramjac), typeof(f.sys), typeof(f.colorvec),
typeof(f.initialization_data)}(f.f, h,
f.mass_matrix,
f.analytic,
f.tgrad, jac,
Expand All @@ -61,7 +63,8 @@ function ODEFunctionWrapper(f::DiffEqBase.AbstractDDEFunction, h)
Wfact_t,
f.paramjac,
f.sys,
f.colorvec)
f.colorvec,
f.initialization_data)
end

(f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)
Expand Down
7 changes: 6 additions & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn
ksEltype, SolType, F, CacheType, IType, FP, O, dAbsType,
dRelType, H,
tstopsType, discType, FSALType, EventErrorType,
CallbackCacheType, DV} <:
CallbackCacheType, DV, IA} <:
AbstractDDEIntegrator{algType, IIP, uType, tType}
sol::SolType
u::uType
Expand Down Expand Up @@ -95,6 +95,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn
integrator::IType
fsalfirst::FSALType
fsallast::FSALType
initializealg::IA
end

function (integrator::DDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing)
Expand All @@ -105,3 +106,7 @@ function (integrator::DDEIntegrator)(val::AbstractArray, t::Union{Number, Abstra
deriv::Type = Val{0}; idxs = nothing)
OrdinaryDiffEq.current_interpolant!(val, t, integrator, idxs, deriv)
end

function SII.get_history_function(integrator::DDEIntegrator)
return integrator.history
end
21 changes: 19 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
discontinuity_interp_points::Int = 10,
discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12),
discontinuity_reltol = 0,
initializealg = DDEDefaultInit(),
kwargs...)
if haskey(kwargs, :initial_order)
@warn "initial_order has been deprecated. Please specify order_discontinuity_t0 in the DDEProblem instead."
Expand Down Expand Up @@ -350,7 +351,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
typeof(d_discontinuities_propagated),
typeof(fsalfirst),
typeof(last_event_error), typeof(callback_cache),
typeof(differential_vars)}(sol, u, k,
typeof(differential_vars), typeof(initializealg)}(sol, u, k,
t0,
tType(dt),
f_with_history,
Expand Down Expand Up @@ -402,10 +403,11 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
stats,
history,
differential_vars,
ode_integrator, fsalfirst, fsallast)
ode_integrator, fsalfirst, fsallast, initializealg)

# initialize DDE integrator
if initialize_integrator
DiffEqBase.initialize_dae!(integrator)
initialize_solution!(integrator)
OrdinaryDiffEqCore.initialize_callbacks!(integrator, initialize_save)
OrdinaryDiffEqCore.initialize!(integrator)
Expand Down Expand Up @@ -538,3 +540,18 @@ function initialize_tstops_d_discontinuities_propagated(::Type{T}, tstops,

return tstops_propagated, d_discontinuities_propagated
end

struct DDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::DDEIntegrator, initializealg = integrator.initializealg)
OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg,
Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

function OrdinaryDiffEqCore._initialize_dae!(integrator::DDEIntegrator, prob, ::DDEDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
else
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace)
end
end
38 changes: 38 additions & 0 deletions test/integrators/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using DelayDiffEq
using SciMLBase
using LinearAlgebra
using Test

@testset "CheckInit" begin
u0_good = [0.99, 0.01, 0.0]
sir_history(p, t) = [1.0, 0.0, 0.0]
tspan = (0.0, 40.0)
p = (γ = 0.5, τ = 4.0)

function sir_ddae!(du, u, h, p, t)
S, I, R = u
γ, τ = p
infection = γ * I * S
Sd, Id, _ = h(p, t - τ)
recovery = γ * Id * Sd
@inbounds begin
du[1] = -infection
du[2] = infection - recovery
du[3] = S + I + R - 1
end
nothing
end

prob_ddae = DDEProblem(
DDEFunction{true}(sir_ddae!;
mass_matrix = Diagonal([1.0, 1.0, 0.0])),
u0,
sir_history,
tspan,
p;
constant_lags = (p.τ,))
alg = MethodOfSteps(Rosenbrock23())
@test_nowarn init(prob_ddae, alg)
prob.u0[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError init(prob_ddae, alg)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ if GROUP == "All" || GROUP == "Integrators"
@time @safetestset "Verner Tests" begin
include("integrators/verner.jl")
end
@time @safetestset "Initialization" begin
include("integrators/initialization.jl")
end
end

if GROUP == "All" || GROUP == "Regression"
Expand Down
Loading