Skip to content

Commit 24d0d7c

Browse files
authored
Merge pull request #2208 from SciML/myb/sdde
SDDE support
2 parents e63aad0 + 46c8a2f commit 24d0d7c

File tree

4 files changed

+147
-11
lines changed

4 files changed

+147
-11
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,10 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
114114
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
115115
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
116116
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
117+
StochasticDelayDiffEq = "29a0d76e-afc8-11e9-03a4-eda52ae4b960"
117118
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
118119
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
119120
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
120121

121122
[targets]
122-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
123+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq"]

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,14 @@ function isdelay(var, iv)
169169
return false
170170
end
171171
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
172-
function delay_to_function(sys::AbstractODESystem)
173-
delay_to_function(full_equations(sys),
172+
function delay_to_function(sys::AbstractODESystem, eqs = full_equations(sys))
173+
delay_to_function(eqs,
174174
get_iv(sys),
175175
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
176176
parameters(sys),
177177
DDE_HISTORY_FUN)
178178
end
179-
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
179+
function delay_to_function(eqs::Vector, iv, sts, ps, h)
180180
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
181181
end
182182
function delay_to_function(eq::Equation, iv, sts, ps, h)
@@ -548,8 +548,8 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
548548
expression_module = eval_module, checkbounds = checkbounds,
549549
kwargs...)
550550
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
551-
f(u, p, h, t) = f_oop(u, p, h, t)
552-
f(du, u, p, h, t) = f_iip(du, u, p, h, t)
551+
f(u, h, p, t) = f_oop(u, h, p, t)
552+
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
553553

554554
DDEFunction{iip}(f,
555555
sys = sys,
@@ -558,6 +558,35 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
558558
paramsyms = Symbol.(ps))
559559
end
560560

561+
function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
562+
SDDEFunction{true}(sys, args...; kwargs...)
563+
end
564+
565+
function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
566+
ps = parameters(sys), u0 = nothing;
567+
eval_module = @__MODULE__,
568+
checkbounds = false,
569+
kwargs...) where {iip}
570+
f_gen = generate_function(sys, dvs, ps; isdde = true,
571+
expression = Val{true},
572+
expression_module = eval_module, checkbounds = checkbounds,
573+
kwargs...)
574+
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
575+
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
576+
isdde = true, kwargs...)
577+
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
578+
f(u, h, p, t) = f_oop(u, h, p, t)
579+
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
580+
g(u, h, p, t) = g_oop(u, h, p, t)
581+
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
582+
583+
SDDEFunction{iip}(f, g,
584+
sys = sys,
585+
syms = Symbol.(dvs),
586+
indepsym = Symbol(get_iv(sys)),
587+
paramsyms = Symbol.(ps))
588+
end
589+
561590
"""
562591
```julia
563592
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
@@ -941,6 +970,72 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
941970
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
942971
end
943972

973+
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
974+
SDDEProblem{true}(sys, args...; kwargs...)
975+
end
976+
function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
977+
tspan = get_tspan(sys),
978+
parammap = DiffEqBase.NullParameters();
979+
callback = nothing,
980+
check_length = true,
981+
sparsenoise = nothing,
982+
kwargs...) where {iip}
983+
has_difference = any(isdifferenceeq, equations(sys))
984+
f, u0, p = process_DEProblem(SDDEFunction{iip}, sys, u0map, parammap;
985+
t = tspan !== nothing ? tspan[1] : tspan,
986+
has_difference = has_difference,
987+
symbolic_u0 = true,
988+
check_length, kwargs...)
989+
h_oop, h_iip = generate_history(sys, u0)
990+
h(out, p, t) = h_iip(out, p, t)
991+
h(p, t) = h_oop(p, t)
992+
u0 = h(p, tspan[1])
993+
cbs = process_events(sys; callback, has_difference, kwargs...)
994+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
995+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
996+
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
997+
if clock isa Clock
998+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
999+
else
1000+
error("$clock is not a supported clock type.")
1001+
end
1002+
end
1003+
if cbs === nothing
1004+
if length(discrete_cbs) == 1
1005+
cbs = only(discrete_cbs)
1006+
else
1007+
cbs = CallbackSet(discrete_cbs...)
1008+
end
1009+
else
1010+
cbs = CallbackSet(cbs, discrete_cbs)
1011+
end
1012+
else
1013+
svs = nothing
1014+
end
1015+
kwargs = filter_kwargs(kwargs)
1016+
1017+
kwargs1 = (;)
1018+
if cbs !== nothing
1019+
kwargs1 = merge(kwargs1, (callback = cbs,))
1020+
end
1021+
if svs !== nothing
1022+
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1023+
end
1024+
1025+
noiseeqs = get_noiseeqs(sys)
1026+
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
1027+
if noiseeqs isa AbstractVector
1028+
noise_rate_prototype = nothing
1029+
elseif sparsenoise
1030+
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
1031+
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
1032+
else
1033+
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
1034+
end
1035+
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
1036+
noise_rate_prototype, kwargs1..., kwargs...)
1037+
end
1038+
9441039
"""
9451040
```julia
9461041
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,

src/systems/diffeqs/sdesystem.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,18 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
210210
end
211211

212212
function generate_diffusion_function(sys::SDESystem, dvs = states(sys),
213-
ps = parameters(sys); kwargs...)
214-
return build_function(get_noiseeqs(sys),
215-
map(x -> time_varying_as_func(value(x), sys), dvs),
216-
map(x -> time_varying_as_func(value(x), sys), ps),
217-
get_iv(sys); kwargs...)
213+
ps = parameters(sys); isdde = false, kwargs...)
214+
eqs = get_noiseeqs(sys)
215+
if isdde
216+
eqs = delay_to_function(sys, eqs)
217+
end
218+
u = map(x -> time_varying_as_func(value(x), sys), dvs)
219+
p = map(x -> time_varying_as_func(value(x), sys), ps)
220+
if isdde
221+
return build_function(eqs, u, DDE_HISTORY_FUN, p, get_iv(sys); kwargs...)
222+
else
223+
return build_function(eqs, u, p, get_iv(sys); kwargs...)
224+
end
218225
end
219226

220227
"""

test/dde.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,36 @@ prob2 = DDEProblem(sys,
4949
constant_lags = [tau])
5050
sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
5151
@test sol2_mtk.u[end] sol2.u[end]
52+
53+
using StochasticDelayDiffEq
54+
function hayes_modelf(du, u, h, p, t)
55+
τ, a, b, c, α, β, γ = p
56+
du .= a .* u .+ b .* h(p, t - τ) .+ c
57+
end
58+
function hayes_modelg(du, u, h, p, t)
59+
τ, a, b, c, α, β, γ = p
60+
du .= α .* u .+ γ
61+
end
62+
h(p, t) = (ones(1) .+ t);
63+
tspan = (0.0, 10.0)
64+
65+
pmul = [1.0,
66+
-4.0, -2.0, 10.0,
67+
-1.3, -1.2, 1.1]
68+
69+
prob = SDDEProblem(hayes_modelf, hayes_modelg, [1.0], h, tspan, pmul;
70+
constant_lags = (pmul[1],));
71+
sol = solve(prob, RKMil())
72+
73+
@variables t x(..)
74+
@parameters a=-4.0 b=-2.0 c=10.0 α=-1.3 β=-1.2 γ=1.1
75+
D = Differential(t)
76+
@brownian η
77+
τ = 1.0
78+
eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c +* x(t) + γ) * η]
79+
@named sys = System(eqs)
80+
sys = structural_simplify(sys)
81+
@test equations(sys) == [D(x(t)) ~ a * x(t) + b * x(t - τ) + c]
82+
@test isequal(ModelingToolkit.get_noiseeqs(sys), [α * x(t) + γ;;])
83+
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
84+
@test_nowarn sol_mtk = solve(prob_mtk, RKMil())

0 commit comments

Comments
 (0)