Skip to content

Commit 549885d

Browse files
authored
Merge pull request #1126 from sharanry/sy/difference_ODESys
Enable usage of difference operator in ODESystem
2 parents 2be7685 + 477d34c commit 549885d

File tree

7 files changed

+190
-85
lines changed

7 files changed

+190
-85
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
12+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1213
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
1314
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1415
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -48,6 +49,7 @@ ArrayInterface = "2.8, 3.0"
4849
ConstructionBase = "1"
4950
DataStructures = "0.17, 0.18"
5051
DiffEqBase = "6.54.0"
52+
DiffEqCallbacks = "2.16"
5153
DiffEqJump = "6.7.5"
5254
DiffRules = "0.1, 1.0"
5355
Distributions = "0.23, 0.24, 0.25"

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using DataStructures
1717
using SpecialFunctions, NaNMath
1818
using RuntimeGeneratedFunctions
1919
using Base.Threads
20+
using DiffEqCallbacks
2021
import MacroTools: splitdef, combinedef, postwalk, striplines
2122
import Libdl
2223
using DocStringExtensions
@@ -182,6 +183,7 @@ export calculate_hessian, generate_hessian
182183
export calculate_massmatrix, generate_diffusion_function
183184
export stochastic_integral_transform
184185
export initialize_system_structure
186+
export generate_difference_cb
185187

186188
export BipartiteGraph, equation_dependencies, variable_dependencies
187189
export eqeq_dependencies, varvar_dependencies

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function generate_function(
9090
#obsvars = map(eq->eq.lhs, observed(sys))
9191
#fulldvs = [dvs; obsvars]
9292

93-
eqs = equations(sys)
93+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
9494
foreach(check_derivative_variables, eqs)
9595
# substitute x(t) by just x
9696
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
@@ -109,6 +109,46 @@ function generate_function(
109109
end
110110
end
111111

112+
@inline function allequal(x)
113+
length(x) < 2 && return true
114+
e1 = first(x)
115+
i = 2
116+
@inbounds for i=2:length(x)
117+
x[i] == e1 || return false
118+
end
119+
return true
120+
end
121+
122+
function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = parameters(sys);
123+
kwargs...)
124+
eqs = equations(sys)
125+
foreach(check_difference_variables, eqs)
126+
127+
rhss = [
128+
begin
129+
ind = findfirst(eq -> isdifference(eq.lhs) && isequal(arguments(eq.lhs)[1], s), eqs)
130+
ind === nothing ? 0 : eqs[ind].rhs
131+
end
132+
for s in dvs ]
133+
134+
u = map(x->time_varying_as_func(value(x), sys), dvs)
135+
p = map(x->time_varying_as_func(value(x), sys), ps)
136+
t = get_iv(sys)
137+
138+
f_oop, f_iip = build_function(rhss, u, p, t; kwargs...)
139+
140+
f = @RuntimeGeneratedFunction(@__MODULE__, f_oop)
141+
142+
function cb_affect!(int)
143+
int.u += f(int.u, int.p, int.t)
144+
end
145+
146+
dts = [ operation(eq.lhs).dt for eq in eqs if isdifferenceeq(eq)]
147+
allequal(dts) || error("All difference variables should have same time steps.")
148+
149+
PeriodicCallback(cb_affect!, first(dts))
150+
end
151+
112152
function time_varying_as_func(x, sys)
113153
# if something is not x(t) (the current state)
114154
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
@@ -124,7 +164,7 @@ function time_varying_as_func(x, sys)
124164
end
125165

126166
function calculate_massmatrix(sys::AbstractODESystem; simplify=false)
127-
eqs = equations(sys)
167+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
128168
dvs = states(sys)
129169
M = zeros(length(eqs),length(eqs))
130170
state2idx = Dict(s => i for (i, s) in enumerate(dvs))
@@ -536,7 +576,12 @@ symbolically calculating numerical enhancements.
536576
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
537577
parammap=DiffEqBase.NullParameters();kwargs...) where iip
538578
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
539-
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
579+
if any(isdifferenceeq.(equations(sys)))
580+
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys),kwargs...)
581+
else
582+
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
583+
end
584+
540585
end
541586

542587
"""
@@ -563,7 +608,12 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
563608
diffvars = collect_differential_variables(sys)
564609
sts = states(sys)
565610
differential_vars = map(Base.Fix2(in, diffvars), sts)
566-
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
611+
if any(isdifferenceeq.(equations(sys)))
612+
DAEProblem{iip}(f,du0,u0,tspan,p;difference_cb=generate_difference_cb(sys),differential_vars=differential_vars,kwargs...)
613+
else
614+
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
615+
end
616+
567617
end
568618

569619
"""
@@ -713,6 +763,3 @@ end
713763
function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...)
714764
SteadyStateProblemExpr{true}(sys, args...; kwargs...)
715765
end
716-
717-
isdifferential(expr) = istree(expr) && operation(expr) isa Differential
718-
isdiffeq(eq) = isdifferential(eq.lhs)

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -134,42 +134,6 @@ function ODESystem(
134134
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
135135
end
136136

137-
vars(x::Sym) = Set([x])
138-
vars(exprs::Symbolic) = vars([exprs])
139-
vars(exprs) = foldl(vars!, exprs; init = Set())
140-
vars!(vars, eq::Equation) = (vars!(vars, eq.lhs); vars!(vars, eq.rhs); vars)
141-
function vars!(vars, O)
142-
if isa(O, Sym)
143-
return push!(vars, O)
144-
end
145-
!istree(O) && return vars
146-
147-
operation(O) isa Differential && return push!(vars, O)
148-
149-
if operation(O) === (getindex) &&
150-
first(arguments(O)) isa Symbolic
151-
152-
return push!(vars, O)
153-
end
154-
155-
symtype(operation(O)) <: FnType && push!(vars, O)
156-
for arg in arguments(O)
157-
vars!(vars, arg)
158-
end
159-
160-
return vars
161-
end
162-
163-
find_derivatives!(vars, expr::Equation, f=identity) = (find_derivatives!(vars, expr.lhs, f); find_derivatives!(vars, expr.rhs, f); vars)
164-
function find_derivatives!(vars, expr, f)
165-
!istree(O) && return vars
166-
operation(O) isa Differential && push!(vars, f(O))
167-
for arg in arguments(O)
168-
vars!(vars, arg)
169-
end
170-
return vars
171-
end
172-
173137
function ODESystem(eqs, iv=nothing; kwargs...)
174138
eqs = collect(eqs)
175139
# NOTE: this assumes that the order of algebric equations doesn't matter
@@ -208,30 +172,6 @@ function ODESystem(eqs, iv=nothing; kwargs...)
208172
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
209173
end
210174

211-
function collect_vars!(states, parameters, expr, iv)
212-
if expr isa Sym
213-
collect_var!(states, parameters, expr, iv)
214-
else
215-
for var in vars(expr)
216-
if istree(var) && operation(var) isa Differential
217-
var, _ = var_from_nested_derivative(var)
218-
end
219-
collect_var!(states, parameters, var, iv)
220-
end
221-
end
222-
return nothing
223-
end
224-
225-
function collect_var!(states, parameters, var, iv)
226-
isequal(var, iv) && return nothing
227-
if isparameter(var) || (istree(var) && isparameter(operation(var)))
228-
push!(parameters, var)
229-
else
230-
push!(states, var)
231-
end
232-
return nothing
233-
end
234-
235175
# NOTE: equality does not check cached Jacobian
236176
function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
237177
iv1 = independent_variable(sys1)
@@ -319,21 +259,6 @@ function _eq_unordered(a, b)
319259
return true
320260
end
321261

322-
function collect_differential_variables(sys::ODESystem)
323-
eqs = equations(sys)
324-
vars = Set()
325-
diffvars = Set()
326-
for eq in eqs
327-
vars!(vars, eq)
328-
for v in vars
329-
isdifferential(v) || continue
330-
push!(diffvars, arguments(v)[1])
331-
end
332-
empty!(vars)
333-
end
334-
return diffvars
335-
end
336-
337262
# We have a stand-alone function to convert a `NonlinearSystem` or `ODESystem`
338263
# to an `ODESystem` to connect systems, and we later can reply on
339264
# `structural_simplify` to convert `ODESystem`s to `NonlinearSystem`s.

src/systems/discrete_system/discrete_system.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,6 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
122122
DiscreteProblem(f,u0,tspan,p;kwargs...)
123123
end
124124

125-
isdifference(expr) = istree(expr) && operation(expr) isa Difference
126-
isdifferenceeq(eq) = isdifference(eq.lhs)
127-
128125
check_difference_variables(eq) = check_operator_variables(eq, Difference)
129126

130127
function generate_function(

src/utils.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,93 @@ function check_operator_variables(eq, op::Type, expr=eq.rhs)
187187
end
188188
foreach(expr -> check_operator_variables(eq, op, expr), arguments(expr))
189189
end
190+
191+
isdifferential(expr) = istree(expr) && operation(expr) isa Differential
192+
isdiffeq(eq) = isdifferential(eq.lhs)
193+
194+
isdifference(expr) = istree(expr) && operation(expr) isa Difference
195+
isdifferenceeq(eq) = isdifference(eq.lhs)
196+
197+
vars(x::Sym; op=Differential) = Set([x])
198+
vars(exprs::Symbolic; op=Differential) = vars([exprs]; op=op)
199+
vars(exprs; op=Differential) = foldl((x, y) -> vars!(x, y; op=op), exprs; init = Set())
200+
vars!(vars, eq::Equation; op=Differential) = (vars!(vars, eq.lhs; op=op); vars!(vars, eq.rhs; op=op); vars)
201+
function vars!(vars, O; op=Differential)
202+
if isa(O, Sym)
203+
return push!(vars, O)
204+
end
205+
!istree(O) && return vars
206+
207+
operation(O) isa op && return push!(vars, O)
208+
209+
if operation(O) === (getindex) &&
210+
first(arguments(O)) isa Symbolic
211+
212+
return push!(vars, O)
213+
end
214+
215+
symtype(operation(O)) <: FnType && push!(vars, O)
216+
for arg in arguments(O)
217+
vars!(vars, arg; op=op)
218+
end
219+
220+
return vars
221+
end
222+
difference_vars(x::Sym) = vars(x; op=Difference)
223+
difference_vars(exprs::Symbolic) = vars(exprs; op=Difference)
224+
difference_vars(exprs) = vars(exprs; op=Difference)
225+
difference_vars!(vars, eq::Equation) = vars!(vars, eq; op=Difference)
226+
difference_vars!(vars, O) = vars!(vars, O; op=Difference)
227+
228+
function collect_operator_variables(sys, isop::Function)
229+
eqs = equations(sys)
230+
vars = Set()
231+
diffvars = Set()
232+
for eq in eqs
233+
vars!(vars, eq)
234+
for v in vars
235+
isop(v) || continue
236+
push!(diffvars, arguments(v)[1])
237+
end
238+
empty!(vars)
239+
end
240+
return diffvars
241+
end
242+
collect_differential_variables(sys) = collect_operator_variables(sys, isdifferential)
243+
collect_difference_variables(sys) = collect_operator_variables(sys, isdifference)
244+
245+
#
246+
247+
find_derivatives!(vars, expr::Equation, f=identity) = (find_derivatives!(vars, expr.lhs, f); find_derivatives!(vars, expr.rhs, f); vars)
248+
function find_derivatives!(vars, expr, f)
249+
!istree(O) && return vars
250+
operation(O) isa Differential && push!(vars, f(O))
251+
for arg in arguments(O)
252+
vars!(vars, arg)
253+
end
254+
return vars
255+
end
256+
257+
function collect_vars!(states, parameters, expr, iv)
258+
if expr isa Sym
259+
collect_var!(states, parameters, expr, iv)
260+
else
261+
for var in vars(expr)
262+
if istree(var) && operation(var) isa Differential
263+
var, _ = var_from_nested_derivative(var)
264+
end
265+
collect_var!(states, parameters, var, iv)
266+
end
267+
end
268+
return nothing
269+
end
270+
271+
function collect_var!(states, parameters, var, iv)
272+
isequal(var, iv) && return nothing
273+
if isparameter(var) || (istree(var) && isparameter(operation(var)))
274+
push!(parameters, var)
275+
else
276+
push!(states, var)
277+
end
278+
return nothing
279+
end

test/odesystem.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,45 @@ let
387387
sys = ODESystem(D.(x) .~ x)
388388
@test_nowarn structural_simplify(sys)
389389
end
390+
391+
# Mixed Difference Differential equations
392+
@parameters t a b c d
393+
@variables x(t) y(t)
394+
δ = Differential(t)
395+
D = Difference(t; dt=0.1)
396+
eqs = [
397+
δ(x) ~ a*x - b*x*y,
398+
δ(y) ~ -c*y + d*x*y,
399+
D(x) ~ y
400+
]
401+
de = ODESystem(eqs,t,[x,y],[a,b,c,d])
402+
@test generate_difference_cb(de) isa ModelingToolkit.DiffEqCallbacks.DiscreteCallback
403+
404+
# doesn't work with ODEFunction
405+
# prob = ODEProblem(ODEFunction{false}(de),[1.0,1.0],(0.0,1.0),[1.5,1.0,3.0,1.0])
406+
407+
prob = ODEProblem(de,[1.0,1.0],(0.0,1.0),[1.5,1.0,3.0,1.0], check_length=false)
408+
@test prob.kwargs[:difference_cb] isa ModelingToolkit.DiffEqCallbacks.DiscreteCallback
409+
410+
sol = solve(prob, Tsit5(); callback=prob.kwargs[:difference_cb], tstops=prob.tspan[1]:0.1:prob.tspan[2])
411+
412+
# Direct implementation
413+
function lotka(du,u,p,t)
414+
x = u[1]
415+
y = u[2]
416+
du[1] = p[1]*x - p[2]*x*y
417+
du[2] = -p[3]*y + p[4]*x*y
418+
end
419+
420+
prob2 = ODEProblem(lotka,[1.0,1.0],(0.0,1.0),[1.5,1.0,3.0,1.0])
421+
function periodic_difference_affect!(int)
422+
int.u += [int.u[2], 0]
423+
end
424+
425+
difference_cb = ModelingToolkit.PeriodicCallback(periodic_difference_affect!, 0.1)
426+
427+
sol2 = solve(prob2, Tsit5(); callback=difference_cb, tstops=collect(prob.tspan[1]:0.1:prob.tspan[2])[2:end]
428+
)
429+
430+
@test sol(0:0.01:1)[x] sol2(0:0.01:1)[1,:]
431+
@test sol(0:0.01:1)[y] sol2(0:0.01:1)[2,:]

0 commit comments

Comments
 (0)