Skip to content

Commit 008071d

Browse files
committed
SDDE support
1 parent e63aad0 commit 008071d

File tree

3 files changed

+145
-10
lines changed

3 files changed

+145
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,14 @@ function isdelay(var, iv)
169169
return false
170170
end
171171
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
172-
function delay_to_function(sys::AbstractODESystem)
173-
delay_to_function(full_equations(sys),
172+
function delay_to_function(sys::AbstractODESystem, eqs = full_equations(sys))
173+
delay_to_function(eqs,
174174
get_iv(sys),
175175
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
176176
parameters(sys),
177177
DDE_HISTORY_FUN)
178178
end
179-
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
179+
function delay_to_function(eqs::Vector, iv, sts, ps, h)
180180
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
181181
end
182182
function delay_to_function(eq::Equation, iv, sts, ps, h)
@@ -548,8 +548,8 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
548548
expression_module = eval_module, checkbounds = checkbounds,
549549
kwargs...)
550550
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
551-
f(u, p, h, t) = f_oop(u, p, h, t)
552-
f(du, u, p, h, t) = f_iip(du, u, p, h, t)
551+
f(u, h, p, t) = f_oop(u, h, p, t)
552+
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
553553

554554
DDEFunction{iip}(f,
555555
sys = sys,
@@ -558,6 +558,36 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
558558
paramsyms = Symbol.(ps))
559559
end
560560

561+
function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
562+
SDDEFunction{true}(sys, args...; kwargs...)
563+
end
564+
565+
function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
566+
ps = parameters(sys), u0 = nothing;
567+
eval_module = @__MODULE__,
568+
checkbounds = false,
569+
kwargs...) where {iip}
570+
f_gen = generate_function(sys, dvs, ps; isdde = true,
571+
expression = Val{true},
572+
expression_module = eval_module, checkbounds = checkbounds,
573+
kwargs...)
574+
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
575+
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
576+
isdde = true, kwargs...)
577+
@show g_gen[2]
578+
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
579+
f(u, h, p, t) = f_oop(u, h, p, t)
580+
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
581+
g(u, h, p, t) = g_oop(u, h, p, t)
582+
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
583+
584+
SDDEFunction{iip}(f, g,
585+
sys = sys,
586+
syms = Symbol.(dvs),
587+
indepsym = Symbol(get_iv(sys)),
588+
paramsyms = Symbol.(ps))
589+
end
590+
561591
"""
562592
```julia
563593
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
@@ -941,6 +971,72 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
941971
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
942972
end
943973

974+
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
975+
SDDEProblem{true}(sys, args...; kwargs...)
976+
end
977+
function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
978+
tspan = get_tspan(sys),
979+
parammap = DiffEqBase.NullParameters();
980+
callback = nothing,
981+
check_length = true,
982+
sparsenoise = nothing,
983+
kwargs...) where {iip}
984+
has_difference = any(isdifferenceeq, equations(sys))
985+
f, u0, p = process_DEProblem(SDDEFunction{iip}, sys, u0map, parammap;
986+
t = tspan !== nothing ? tspan[1] : tspan,
987+
has_difference = has_difference,
988+
symbolic_u0 = true,
989+
check_length, kwargs...)
990+
h_oop, h_iip = generate_history(sys, u0)
991+
h(out, p, t) = h_iip(out, p, t)
992+
h(p, t) = h_oop(p, t)
993+
u0 = h(p, tspan[1])
994+
cbs = process_events(sys; callback, has_difference, kwargs...)
995+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
996+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
997+
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
998+
if clock isa Clock
999+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1000+
else
1001+
error("$clock is not a supported clock type.")
1002+
end
1003+
end
1004+
if cbs === nothing
1005+
if length(discrete_cbs) == 1
1006+
cbs = only(discrete_cbs)
1007+
else
1008+
cbs = CallbackSet(discrete_cbs...)
1009+
end
1010+
else
1011+
cbs = CallbackSet(cbs, discrete_cbs)
1012+
end
1013+
else
1014+
svs = nothing
1015+
end
1016+
kwargs = filter_kwargs(kwargs)
1017+
1018+
kwargs1 = (;)
1019+
if cbs !== nothing
1020+
kwargs1 = merge(kwargs1, (callback = cbs,))
1021+
end
1022+
if svs !== nothing
1023+
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1024+
end
1025+
1026+
noiseeqs = get_noiseeqs(sys)
1027+
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
1028+
if noiseeqs isa AbstractVector
1029+
noise_rate_prototype = nothing
1030+
elseif sparsenoise
1031+
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
1032+
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
1033+
else
1034+
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
1035+
end
1036+
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
1037+
noise_rate_prototype, kwargs1..., kwargs...)
1038+
end
1039+
9441040
"""
9451041
```julia
9461042
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,

src/systems/diffeqs/sdesystem.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,18 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
210210
end
211211

212212
function generate_diffusion_function(sys::SDESystem, dvs = states(sys),
213-
ps = parameters(sys); kwargs...)
214-
return build_function(get_noiseeqs(sys),
215-
map(x -> time_varying_as_func(value(x), sys), dvs),
216-
map(x -> time_varying_as_func(value(x), sys), ps),
217-
get_iv(sys); kwargs...)
213+
ps = parameters(sys); isdde = false, kwargs...)
214+
eqs = get_noiseeqs(sys)
215+
if isdde
216+
eqs = delay_to_function(sys, eqs)
217+
end
218+
u = map(x -> time_varying_as_func(value(x), sys), dvs)
219+
p = map(x -> time_varying_as_func(value(x), sys), ps)
220+
if isdde
221+
return build_function(eqs, u, DDE_HISTORY_FUN, p, get_iv(sys); kwargs...)
222+
else
223+
return build_function(eqs, u, p, get_iv(sys); kwargs...)
224+
end
218225
end
219226

220227
"""

test/dde.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,35 @@ prob2 = DDEProblem(sys,
4949
constant_lags = [tau])
5050
sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
5151
@test sol2_mtk.u[end] sol2.u[end]
52+
53+
using StochasticDelayDiffEq
54+
function hayes_modelf(du, u, h, p, t)
55+
τ, a, b, c, α, β, γ = p
56+
du .= a .* u .+ b .* h(p, t - τ) .+ c
57+
end
58+
function hayes_modelg(du, u, h, p, t)
59+
τ, a, b, c, α, β, γ = p
60+
du .= α .* u .+ γ
61+
end
62+
h(p, t) = (ones(1) .+ t);
63+
tspan = (0.0, 10.0)
64+
65+
pmul = [1.0,
66+
-4.0, -2.0, 10.0,
67+
-1.3, -1.2, 1.1]
68+
69+
prob = SDDEProblem(hayes_modelf, hayes_modelg, [1.0], h, tspan, pmul;
70+
constant_lags = (pmul[1],));
71+
sol = solve(prob, RKMil())
72+
73+
@variables t x(..)
74+
@parameters a=-4.0 b=-2.0 c=10.0 α=-1.3 β=-1.2 γ=1.1
75+
D = Differential(t)
76+
@brownian η
77+
τ = 1.0
78+
eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c +* x(t) + γ) * η]
79+
@named sys = System(eqs)
80+
sys = structural_simplify(sys)
81+
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
82+
sol_mtk = solve(prob_mtk, RKMil())
83+
@test sol.u[end] sol_mtk.u[end]

0 commit comments

Comments
 (0)