Skip to content

Commit 3b2a68b

Browse files
committed
Enable usage of difference operator in ODESystem
1 parent 6d52839 commit 3b2a68b

File tree

6 files changed

+127
-82
lines changed

6 files changed

+127
-82
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
12+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1213
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
1314
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1415
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -48,6 +49,7 @@ ArrayInterface = "2.8, 3.0"
4849
ConstructionBase = "1"
4950
DataStructures = "0.17, 0.18"
5051
DiffEqBase = "6.54.0"
52+
DiffEqCallbacks = "2.16"
5153
DiffEqJump = "6.7.5"
5254
DiffRules = "0.1, 1.0"
5355
Distributions = "0.23, 0.24, 0.25"

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using DataStructures
1717
using SpecialFunctions, NaNMath
1818
using RuntimeGeneratedFunctions
1919
using Base.Threads
20+
using DiffEqCallbacks
2021
import MacroTools: splitdef, combinedef, postwalk, striplines
2122
import Libdl
2223
using DocStringExtensions
@@ -180,6 +181,7 @@ export calculate_hessian, generate_hessian
180181
export calculate_massmatrix, generate_diffusion_function
181182
export stochastic_integral_transform
182183
export initialize_system_structure
184+
export generate_difference_cb
183185

184186
export BipartiteGraph, equation_dependencies, variable_dependencies
185187
export eqeq_dependencies, varvar_dependencies

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function generate_function(
9494
foreach(check_derivative_variables, eqs)
9595
# substitute x(t) by just x
9696
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
97-
[eq.rhs for eq in eqs]
97+
[eq.rhs for eq in eqs if isdiffeq(eq)]
9898
#rhss = Let(obss, rhss)
9999

100100
# TODO: add an optional check on the ordering of observed equations
@@ -109,6 +109,38 @@ function generate_function(
109109
end
110110
end
111111

112+
function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = parameters(sys);
113+
kwargs...)
114+
eqs = equations(sys)
115+
foreach(check_difference_variables, eqs)
116+
# substitute x(t) by just x
117+
118+
# map(x -> isempty(x) ? Val{0} : first(x),
119+
rhss = [
120+
begin
121+
ind = findfirst(eq -> isdifference(eq.lhs) && isequal(arguments(eq.lhs)[1], s), eqs)
122+
ind === nothing ? Val{0} : eqs[ind].rhs
123+
end
124+
for s in dvs ]
125+
126+
u = map(x->time_varying_as_func(value(x), sys), dvs)
127+
p = map(x->time_varying_as_func(value(x), sys), ps)
128+
t = get_iv(sys)
129+
130+
f_oop, f_iip = build_function(rhss, u, p, t; kwargs...)
131+
132+
f = @RuntimeGeneratedFunction(@__MODULE__, f_oop)
133+
134+
function cb_affect!(int)
135+
int.u += f(int.u, int.p, int.t)
136+
end
137+
138+
dts = [ operation(eq.lhs).dt for eq in eqs if isdifferenceeq(eq)]
139+
all(dts .== dts[1]) || error("All difference variables should have same time steps.")
140+
141+
PeriodicCallback(cb_affect!, dts[1])
142+
end
143+
112144
function time_varying_as_func(x, sys)
113145
# if something is not x(t) (the current state)
114146
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
@@ -713,6 +745,3 @@ end
713745
function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...)
714746
SteadyStateProblemExpr{true}(sys, args...; kwargs...)
715747
end
716-
717-
isdifferential(expr) = istree(expr) && operation(expr) isa Differential
718-
isdiffeq(eq) = isdifferential(eq.lhs)

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -133,42 +133,6 @@ function ODESystem(
133133
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
134134
end
135135

136-
vars(x::Sym) = Set([x])
137-
vars(exprs::Symbolic) = vars([exprs])
138-
vars(exprs) = foldl(vars!, exprs; init = Set())
139-
vars!(vars, eq::Equation) = (vars!(vars, eq.lhs); vars!(vars, eq.rhs); vars)
140-
function vars!(vars, O)
141-
if isa(O, Sym)
142-
return push!(vars, O)
143-
end
144-
!istree(O) && return vars
145-
146-
operation(O) isa Differential && return push!(vars, O)
147-
148-
if operation(O) === (getindex) &&
149-
first(arguments(O)) isa Symbolic
150-
151-
return push!(vars, O)
152-
end
153-
154-
symtype(operation(O)) <: FnType && push!(vars, O)
155-
for arg in arguments(O)
156-
vars!(vars, arg)
157-
end
158-
159-
return vars
160-
end
161-
162-
find_derivatives!(vars, expr::Equation, f=identity) = (find_derivatives!(vars, expr.lhs, f); find_derivatives!(vars, expr.rhs, f); vars)
163-
function find_derivatives!(vars, expr, f)
164-
!istree(O) && return vars
165-
operation(O) isa Differential && push!(vars, f(O))
166-
for arg in arguments(O)
167-
vars!(vars, arg)
168-
end
169-
return vars
170-
end
171-
172136
function ODESystem(eqs, iv=nothing; kwargs...)
173137
# NOTE: this assumes that the order of algebric equations doesn't matter
174138
diffvars = OrderedSet()
@@ -206,30 +170,6 @@ function ODESystem(eqs, iv=nothing; kwargs...)
206170
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
207171
end
208172

209-
function collect_vars!(states, parameters, expr, iv)
210-
if expr isa Sym
211-
collect_var!(states, parameters, expr, iv)
212-
else
213-
for var in vars(expr)
214-
if istree(var) && operation(var) isa Differential
215-
var, _ = var_from_nested_derivative(var)
216-
end
217-
collect_var!(states, parameters, var, iv)
218-
end
219-
end
220-
return nothing
221-
end
222-
223-
function collect_var!(states, parameters, var, iv)
224-
isequal(var, iv) && return nothing
225-
if isparameter(var) || (istree(var) && isparameter(operation(var)))
226-
push!(parameters, var)
227-
else
228-
push!(states, var)
229-
end
230-
return nothing
231-
end
232-
233173
# NOTE: equality does not check cached Jacobian
234174
function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
235175
iv1 = independent_variable(sys1)
@@ -317,21 +257,6 @@ function _eq_unordered(a, b)
317257
return true
318258
end
319259

320-
function collect_differential_variables(sys::ODESystem)
321-
eqs = equations(sys)
322-
vars = Set()
323-
diffvars = Set()
324-
for eq in eqs
325-
vars!(vars, eq)
326-
for v in vars
327-
isdifferential(v) || continue
328-
push!(diffvars, arguments(v)[1])
329-
end
330-
empty!(vars)
331-
end
332-
return diffvars
333-
end
334-
335260
# We have a stand-alone function to convert a `NonlinearSystem` or `ODESystem`
336261
# to an `ODESystem` to connect systems, and we later can reply on
337262
# `structural_simplify` to convert `ODESystem`s to `NonlinearSystem`s.

src/systems/discrete_system/discrete_system.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,6 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
121121
DiscreteProblem(f,u0,tspan,p;kwargs...)
122122
end
123123

124-
isdifference(expr) = istree(expr) && operation(expr) isa Difference
125-
isdifferenceeq(eq) = isdifference(eq.lhs)
126-
127124
check_difference_variables(eq) = check_operator_variables(eq, Difference)
128125

129126
function generate_function(

src/utils.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,93 @@ function check_operator_variables(eq, op::Type, expr=eq.rhs)
187187
end
188188
foreach(expr -> check_operator_variables(eq, op, expr), arguments(expr))
189189
end
190+
191+
isdifferential(expr) = istree(expr) && operation(expr) isa Differential
192+
isdiffeq(eq) = isdifferential(eq.lhs)
193+
194+
isdifference(expr) = istree(expr) && operation(expr) isa Difference
195+
isdifferenceeq(eq) = isdifference(eq.lhs)
196+
197+
vars(x::Sym; op=Differential) = Set([x])
198+
vars(exprs::Symbolic; op=Differential) = vars([exprs]; op=op)
199+
vars(exprs; op=Differential) = foldl((x, y) -> vars!(x, y; op=op), exprs; init = Set())
200+
vars!(vars, eq::Equation; op=Differential) = (vars!(vars, eq.lhs; op=op); vars!(vars, eq.rhs; op=op); vars)
201+
function vars!(vars, O; op=Differential)
202+
if isa(O, Sym)
203+
return push!(vars, O)
204+
end
205+
!istree(O) && return vars
206+
207+
operation(O) isa op && return push!(vars, O)
208+
209+
if operation(O) === (getindex) &&
210+
first(arguments(O)) isa Symbolic
211+
212+
return push!(vars, O)
213+
end
214+
215+
symtype(operation(O)) <: FnType && push!(vars, O)
216+
for arg in arguments(O)
217+
vars!(vars, arg; op=op)
218+
end
219+
220+
return vars
221+
end
222+
difference_vars(x::Sym) = vars(x; op=Difference)
223+
difference_vars(exprs::Symbolic) = vars(exprs; op=Difference)
224+
difference_vars(exprs) = vars(exprs; op=Difference)
225+
difference_vars!(vars, eq::Equation) = vars!(vars, eq; op=Difference)
226+
difference_vars!(vars, O) = vars!(vars, O; op=Difference)
227+
228+
function collect_operator_variables(sys, isop::Function)
229+
eqs = equations(sys)
230+
vars = Set()
231+
diffvars = Set()
232+
for eq in eqs
233+
vars!(vars, eq)
234+
for v in vars
235+
isop(v) || continue
236+
push!(diffvars, arguments(v)[1])
237+
end
238+
empty!(vars)
239+
end
240+
return diffvars
241+
end
242+
collect_differential_variables(sys) = collect_operator_variables(sys, isdifferential)
243+
collect_difference_variables(sys) = collect_operator_variables(sys, isdifference)
244+
245+
#
246+
247+
find_derivatives!(vars, expr::Equation, f=identity) = (find_derivatives!(vars, expr.lhs, f); find_derivatives!(vars, expr.rhs, f); vars)
248+
function find_derivatives!(vars, expr, f)
249+
!istree(O) && return vars
250+
operation(O) isa Differential && push!(vars, f(O))
251+
for arg in arguments(O)
252+
vars!(vars, arg)
253+
end
254+
return vars
255+
end
256+
257+
function collect_vars!(states, parameters, expr, iv)
258+
if expr isa Sym
259+
collect_var!(states, parameters, expr, iv)
260+
else
261+
for var in vars(expr)
262+
if istree(var) && operation(var) isa Differential
263+
var, _ = var_from_nested_derivative(var)
264+
end
265+
collect_var!(states, parameters, var, iv)
266+
end
267+
end
268+
return nothing
269+
end
270+
271+
function collect_var!(states, parameters, var, iv)
272+
isequal(var, iv) && return nothing
273+
if isparameter(var) || (istree(var) && isparameter(operation(var)))
274+
push!(parameters, var)
275+
else
276+
push!(states, var)
277+
end
278+
return nothing
279+
end

0 commit comments

Comments
 (0)