Skip to content

Commit 6834f24

Browse files
authored
add nicer discrete sys -> prob constructor (#1268)
1 parent d877d28 commit 6834f24

File tree

3 files changed

+125
-24
lines changed

3 files changed

+125
-24
lines changed

src/systems/discrete_system/discrete_system.jl

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@ $(FIELDS)
1111
```
1212
using ModelingToolkit
1313
14-
@parameters σ ρ β
15-
@variables t x(t) y(t) z(t) next_x(t) next_y(t) next_z(t)
14+
@parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1
15+
@variables t x(t)=1.0 y(t)=0.0 z(t)=0.0
16+
D = Difference(t; dt=δt)
1617
17-
eqs = [next_x ~ σ*(y-x),
18-
next_y ~ x*(ρ-z)-y,
19-
next_z ~ x*y - β*z]
18+
eqs = [D(x) ~ σ*(y-x),
19+
D(y) ~ x*(ρ-z)-y,
20+
D(z) ~ x*y - β*z]
2021
21-
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β])
22+
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]) # or
23+
@named de = DiscreteSystem(eqs)
2224
```
2325
"""
2426
struct DiscreteSystem <: AbstractTimeDependentSystem
@@ -45,26 +47,21 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
4547
"""
4648
systems::Vector{DiscreteSystem}
4749
"""
48-
default_u0: The default initial conditions to use when initial conditions
49-
are not supplied in `DiscreteSystem`.
50+
defaults: The default values to use when initial conditions and/or
51+
parameters are not supplied in `DiscreteProblem`.
5052
"""
51-
default_u0::Dict
52-
"""
53-
default_p: The default parameters to use when parameters are not supplied
54-
in `DiscreteSystem`.
55-
"""
56-
default_p::Dict
53+
defaults::Dict
5754
"""
5855
type: type of the system
5956
"""
6057
connection_type::Any
61-
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p, connection_type; checks::Bool = true)
58+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connection_type; checks::Bool = true)
6259
if checks
6360
check_variables(dvs, iv)
6461
check_parameters(ps, iv)
6562
all_dimensionless([dvs;ps;iv;ctrls]) ||check_units(discreteEqs)
6663
end
67-
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p, connection_type)
64+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connection_type)
6865
end
6966
end
7067

@@ -93,7 +90,7 @@ function DiscreteSystem(
9390
ctrl′ = value.(controls)
9491

9592
if !(isempty(default_u0) && isempty(default_p))
96-
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ODESystem, force=true)
93+
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :DiscreteSystem, force=true)
9794
end
9895
defaults = todict(defaults)
9996
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
@@ -106,7 +103,46 @@ function DiscreteSystem(
106103
if length(unique(sysnames)) != length(sysnames)
107104
throw(ArgumentError("System names must be unique."))
108105
end
109-
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, default_u0, default_p, connection_type, kwargs...)
106+
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, connection_type, kwargs...)
107+
end
108+
109+
110+
function DiscreteSystem(eqs, iv=nothing; kwargs...)
111+
eqs = collect(eqs)
112+
# NOTE: this assumes that the order of algebric equations doesn't matter
113+
diffvars = OrderedSet()
114+
allstates = OrderedSet()
115+
ps = OrderedSet()
116+
# reorder equations such that it is in the form of `diffeq, algeeq`
117+
diffeq = Equation[]
118+
algeeq = Equation[]
119+
# initial loop for finding `iv`
120+
if iv === nothing
121+
for eq in eqs
122+
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
123+
iv = iv_from_nested_difference(eq.lhs)
124+
break
125+
end
126+
end
127+
end
128+
iv = value(iv)
129+
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
130+
for eq in eqs
131+
collect_vars_difference!(allstates, ps, eq.lhs, iv)
132+
collect_vars_difference!(allstates, ps, eq.rhs, iv)
133+
if isdifferenceeq(eq)
134+
diffvar, _ = var_from_nested_difference(eq.lhs)
135+
isequal(iv, iv_from_nested_difference(eq.lhs)) || throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
136+
diffvar in diffvars && throw(ArgumentError("The difference variable $diffvar is not unique in the system of equations."))
137+
push!(diffvars, diffvar)
138+
push!(diffeq, eq)
139+
else
140+
push!(algeeq, eq)
141+
end
142+
end
143+
algevars = setdiff(allstates, diffvars)
144+
# the orders here are very important!
145+
return DiscreteSystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
110146
end
111147

112148
"""
@@ -123,16 +159,37 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
123159
ps = parameters(sys)
124160
eqs = equations(sys)
125161
eqs = linearize_eqs(sys, eqs)
126-
# defs = defaults(sys)
127-
t = get_iv(sys)
128-
u0 = varmap_to_vars(u0map,dvs)
162+
defs = defaults(sys)
163+
iv = get_iv(sys)
164+
165+
if parammap isa Dict
166+
u0defs = merge(parammap, defs)
167+
elseif eltype(parammap) <: Pair
168+
u0defs = merge(Dict(parammap), defs)
169+
elseif eltype(parammap) <: Number
170+
u0defs = merge(Dict(zip(ps, parammap)), defs)
171+
else
172+
u0defs = defs
173+
end
174+
if u0map isa Dict
175+
pdefs = merge(u0map, defs)
176+
elseif eltype(u0map) <: Pair
177+
pdefs = merge(Dict(u0map), defs)
178+
elseif eltype(u0map) <: Number
179+
pdefs = merge(Dict(zip(dvs, u0map)), defs)
180+
else
181+
pdefs = defs
182+
end
183+
184+
u0 = varmap_to_vars(u0map,dvs; defaults=u0defs)
185+
129186
rhss = [eq.rhs for eq in eqs]
130187
u = dvs
131-
p = varmap_to_vars(parammap,ps)
188+
p = varmap_to_vars(parammap,ps; defaults=pdefs)
132189

133190
f_gen = generate_function(sys; expression=Val{eval_expression}, expression_module=eval_module)
134191
f_oop, _ = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)
135-
f(u,p,t) = f_oop(u,p,t)
192+
f(u,p,iv) = f_oop(u,p,iv)
136193
DiscreteProblem(f,u0,tspan,p;kwargs...)
137194
end
138195

src/utils.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,15 @@ isdiffeq(eq) = isdifferential(eq.lhs)
225225
isdifference(expr) = istree(expr) && operation(expr) isa Difference
226226
isdifferenceeq(eq) = isdifference(eq.lhs)
227227

228+
iv_from_nested_difference(x::Term) = operation(x) isa Difference ? iv_from_nested_difference(arguments(x)[1]) : arguments(x)[1]
229+
iv_from_nested_difference(x::Sym) = x
230+
iv_from_nested_difference(x) = missing
231+
232+
var_from_nested_difference(x, i=0) = (missing, missing)
233+
var_from_nested_difference(x::Term,i=0) = operation(x) isa Difference ? var_from_nested_difference(arguments(x)[1], i + 1) : (x, i)
234+
var_from_nested_difference(x::Sym,i=0) = (x, i)
235+
236+
228237
isvariable(x::Num) = isvariable(value(x))
229238
function isvariable(x)
230239
x isa Symbolic || return false
@@ -305,6 +314,20 @@ function collect_vars!(states, parameters, expr, iv)
305314
return nothing
306315
end
307316

317+
function collect_vars_difference!(states, parameters, expr, iv)
318+
if expr isa Sym
319+
collect_var!(states, parameters, expr, iv)
320+
else
321+
for var in vars(expr)
322+
if istree(var) && operation(var) isa Difference
323+
var, _ = var_from_nested_difference(var)
324+
end
325+
collect_var!(states, parameters, var, iv)
326+
end
327+
end
328+
return nothing
329+
end
330+
308331
function collect_var!(states, parameters, var, iv)
309332
isequal(var, iv) && return nothing
310333
if isparameter(var) || (istree(var) && isparameter(operation(var)))

test/discretesystem.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end;
1212
# Independent and dependent variables and parameters
1313
@parameters t c nsteps δt β γ
1414
D = Difference(t; dt=0.1)
15-
@variables S(t) I(t) R(t) next_S(t) next_I(t) next_R(t)
15+
@variables S(t) I(t) R(t)
1616

1717
infection = rate_to_proportion*c*I/(S+I+R),δt)*S
1818
recovery = rate_to_proportion(γ,δt)*I
@@ -35,6 +35,27 @@ prob_map = DiscreteProblem(sys,u0,tspan,p)
3535
using OrdinaryDiffEq
3636
sol_map = solve(prob_map,FunctionMap());
3737

38+
# Using defaults constructor
39+
@parameters t c=10.0 nsteps=400 δt=0.1 β=0.05 γ=0.25
40+
Diff = Difference(t; dt=0.1)
41+
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0
42+
43+
infection2 = rate_to_proportion*c*I/(S+I+R),δt)*S
44+
recovery2 = rate_to_proportion(γ,δt)*I
45+
46+
eqs2 = [D(S) ~ S-infection2,
47+
D(I) ~ I+infection2-recovery2,
48+
D(R) ~ R+recovery2]
49+
50+
@named sys = DiscreteSystem(eqs2; controls = [β, γ])
51+
@test ModelingToolkit.defaults(sys) != Dict()
52+
53+
prob_map2 = DiscreteProblem(sys,[],tspan)
54+
sol_map2 = solve(prob_map,FunctionMap());
55+
56+
@test sol_map.u == sol_map2.u
57+
@test sol_map.prob.p == sol_map2.prob.p
58+
3859
# Direct Implementation
3960

4061
function sir_map!(u_diff,u,p,t)

0 commit comments

Comments
 (0)