Skip to content

Commit 023d370

Browse files
committed
Add DDE lowering
1 parent 207adaa commit 023d370

File tree

4 files changed

+139
-19
lines changed

4 files changed

+139
-19
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 130 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
120120
implicit_dae = false,
121121
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
122122
nothing,
123+
isdde = false,
123124
has_difference = false,
124125
kwargs...)
125-
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
126+
if isdde
127+
eqs = delay_to_function(sys)
128+
else
129+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
130+
end
126131
if !implicit_dae
127132
check_operator_variables(eqs, Differential)
128133
check_lhs(eqs, Differential, Set(dvs))
@@ -136,15 +141,58 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
136141
p = map(x -> time_varying_as_func(value(x), sys), ps)
137142
t = get_iv(sys)
138143

139-
pre, sol_states = get_substitutions_and_solved_states(sys,
140-
no_postprocess = has_difference)
144+
if isdde
145+
build_function(rhss, u, DDE_HISTORY_FUN, p, t; kwargs...)
146+
else
147+
pre, sol_states = get_substitutions_and_solved_states(sys,
148+
no_postprocess = has_difference)
141149

142-
if implicit_dae
143-
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states,
144-
kwargs...)
150+
if implicit_dae
151+
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
152+
states = sol_states,
153+
kwargs...)
154+
else
155+
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
156+
kwargs...)
157+
end
158+
end
159+
end
160+
161+
function isdelay(var, iv)
162+
isvariable(var) || return false
163+
if istree(var) && !ModelingToolkit.isoperator(var, Symbolics.Operator)
164+
args = arguments(var)
165+
length(args) == 1 || return false
166+
isequal(args[1], iv) || return true
167+
end
168+
return false
169+
end
170+
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
171+
function delay_to_function(sys::AbstractODESystem)
172+
delay_to_function(full_equations(sys),
173+
get_iv(sys),
174+
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
175+
parameters(sys),
176+
DDE_HISTORY_FUN)
177+
end
178+
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
179+
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
180+
end
181+
function delay_to_function(eq::Equation, iv, sts, ps, h)
182+
delay_to_function(eq.lhs, iv, sts, ps, h) ~ delay_to_function(eq.rhs, iv, sts, ps, h)
183+
end
184+
function delay_to_function(expr, iv, sts, ps, h)
185+
if isdelay(expr, iv)
186+
v = operation(expr)
187+
time = arguments(expr)[1]
188+
idx = sts[v]
189+
return term(getindex, h(Sym{Any}(:ˍ₋arg3), time), idx, type = Real) # BIG BIG HACK
190+
elseif istree(expr)
191+
return similarterm(expr,
192+
operation(expr),
193+
map(x -> delay_to_function(x, iv, sts, ps, h), arguments(expr)))
145194
else
146-
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
147-
kwargs...)
195+
return expr
148196
end
149197
end
150198

@@ -485,6 +533,30 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
485533
observed = observedfun)
486534
end
487535

536+
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
537+
DDEFunction{true}(sys, args...; kwargs...)
538+
end
539+
540+
function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
541+
ps = parameters(sys), u0 = nothing;
542+
eval_module = @__MODULE__,
543+
checkbounds = false,
544+
kwargs...) where {iip}
545+
f_gen = generate_function(sys, dvs, ps; isdde = true,
546+
expression = Val{true},
547+
expression_module = eval_module, checkbounds = checkbounds,
548+
kwargs...)
549+
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
550+
f(u, p, h, t) = f_oop(u, p, h, t)
551+
f(du, u, p, h, t) = f_iip(du, u, p, h, t)
552+
553+
DDEFunction{iip}(f,
554+
sys = sys,
555+
syms = Symbol.(dvs),
556+
indepsym = Symbol(get_iv(sys)),
557+
paramsyms = Symbol.(ps))
558+
end
559+
488560
"""
489561
```julia
490562
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
@@ -577,7 +649,7 @@ end
577649
"""
578650
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
579651
580-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
652+
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
581653
"""
582654
function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
583655
eqs = equations(sys)
@@ -802,6 +874,55 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
802874
end
803875
end
804876

877+
function DiffEqBase.DDEProblem(sys::AbstractODESystem, args...; kwargs...)
878+
DDEProblem{true}(sys, args...; kwargs...)
879+
end
880+
function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
881+
h = (u, p) -> zeros(length(states(sts))),
882+
tspan = get_tspan(sys),
883+
parammap = DiffEqBase.NullParameters();
884+
callback = nothing,
885+
check_length = true,
886+
kwargs...) where {iip}
887+
has_difference = any(isdifferenceeq, equations(sys))
888+
f, u0, p = process_DEProblem(DDEFunction{iip}, sys, u0map, parammap;
889+
t = tspan !== nothing ? tspan[1] : tspan,
890+
has_difference = has_difference,
891+
check_length, kwargs...)
892+
cbs = process_events(sys; callback, has_difference, kwargs...)
893+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
894+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
895+
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
896+
if clock isa Clock
897+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
898+
else
899+
error("$clock is not a supported clock type.")
900+
end
901+
end
902+
if cbs === nothing
903+
if length(discrete_cbs) == 1
904+
cbs = only(discrete_cbs)
905+
else
906+
cbs = CallbackSet(discrete_cbs...)
907+
end
908+
else
909+
cbs = CallbackSet(cbs, discrete_cbs)
910+
end
911+
else
912+
svs = nothing
913+
end
914+
kwargs = filter_kwargs(kwargs)
915+
916+
kwargs1 = (;)
917+
if cbs !== nothing
918+
kwargs1 = merge(kwargs1, (callback = cbs,))
919+
end
920+
if svs !== nothing
921+
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
922+
end
923+
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
924+
end
925+
805926
"""
806927
```julia
807928
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
181181
checks = true,
182182
metadata = nothing,
183183
gui_metadata = nothing)
184+
dvs = filter(x -> !isdelay(x, iv), dvs)
184185
name === nothing &&
185186
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
186187
deqs = scalarize(deqs)
@@ -258,6 +259,10 @@ function ODESystem(eqs, iv = nothing; kwargs...)
258259
push!(algeeq, eq)
259260
end
260261
end
262+
for v in allstates
263+
isdelay(v, iv) || continue
264+
collect_vars!(allstates, ps, arguments(v)[1], iv)
265+
end
261266
algevars = setdiff(allstates, diffvars)
262267
# the orders here are very important!
263268
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,

src/systems/systemstructure.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,7 @@ function TearingState(sys; quick_cancel = false, check = true)
288288
isalgeq = true
289289
statevars = []
290290
for var in vars
291-
if istree(var) && !ModelingToolkit.isoperator(var, Symbolics.Operator)
292-
args = arguments(var)
293-
length(args) == 1 || continue
294-
isequal(args[1], iv) || continue
295-
end
291+
ModelingToolkit.isdelay(var, iv) && continue
296292
set_incidence = true
297293
@label ANOTHER_VAR
298294
_var, _ = var_from_nested_derivative(var)

test/dde.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, DelayDiffEq
1+
using ModelingToolkit, DelayDiffEq, Test
22
p0 = 0.2;
33
q0 = 0.3;
44
v0 = 1;
@@ -33,12 +33,10 @@ eqs = [D(x₀) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (p0 - q0) * x₀ - d0
3333
(v1 / (1 + beta1 * (x₂(t - tau)^2))) * (p1 - q1) * x₁ - d1 * x₁
3434
D(x₂(t)) ~ (v1 / (1 + beta1 * (x₂(t - tau)^2))) * (1 - p1 + q1) * x₁ - d2 * x₂(t)]
3535
@named sys = System(eqs)
36-
h(p, t) = ones(3)
37-
tspan = (0.0, 10.0)
3836
prob = DDEProblem(sys,
3937
[x₀ => 1.0, x₁ => 1.0, x₂(t) => 1.0],
4038
h,
4139
tspan,
4240
constant_lags = [tau])
43-
alg = MethodOfSteps(Tsit5())
44-
sol = solve(prob, alg)
41+
sol_mtk = solve(prob, alg)
42+
@test Array(sol_mtk) Array(sol)

0 commit comments

Comments
 (0)