Skip to content

Commit 9ef2981

Browse files
Merge pull request #2507 from AayushSabharwal/as/discrete-system
feat: initial implementation of new `DiscreteSystem`
2 parents de7ec66 + e9cc50e commit 9ef2981

19 files changed

+1038
-166
lines changed

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pages = [
88
"tutorials/modelingtoolkitize.md",
99
"tutorials/programmatically_generating.md",
1010
"tutorials/stochastic_diffeq.md",
11+
"tutorials/discrete_system.md",
1112
"tutorials/parameter_identifiability.md",
1213
"tutorials/bifurcation_diagram_computation.md",
1314
"tutorials/SampledData.md",

docs/src/tutorials/discrete_system.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# (Experimental) Modeling Discrete Systems
2+
3+
In this example, we will use the new [`DiscreteSystem`](@ref) API
4+
to create an SIR model.
5+
6+
```@example discrete
7+
using ModelingToolkit
8+
using ModelingToolkit: t_nounits as t
9+
using OrdinaryDiffEq: solve, FunctionMap
10+
11+
@inline function rate_to_proportion(r, t)
12+
1 - exp(-r * t)
13+
end
14+
@parameters c δt β γ
15+
@constants h = 1
16+
@variables S(t) I(t) R(t)
17+
k = ShiftIndex(t)
18+
infection = rate_to_proportion(
19+
β * c * I(k - 1) / (S(k - 1) * h + I(k - 1) + R(k - 1)), δt * h) * S(k - 1)
20+
recovery = rate_to_proportion(γ * h, δt) * I(k - 1)
21+
22+
# Equations
23+
eqs = [S(k) ~ S(k - 1) - infection * h,
24+
I(k) ~ I(k - 1) + infection - recovery,
25+
R(k) ~ R(k - 1) + recovery]
26+
@mtkbuild sys = DiscreteSystem(eqs, t)
27+
28+
u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
29+
p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1]
30+
tspan = (0.0, 100.0)
31+
prob = DiscreteProblem(sys, u0, tspan, p)
32+
sol = solve(prob, FunctionMap())
33+
```
34+
35+
All shifts must be non-positive, i.e., discrete-time variables may only be indexed at index
36+
`k, k-1, k-2, ...`. If default values are provided, they are treated as the value of the
37+
variable at the previous timestep. For example, consider the following system to generate
38+
the Fibonacci series:
39+
40+
```@example discrete
41+
@variables x(t) = 1.0
42+
@mtkbuild sys = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
43+
```
44+
45+
The "default value" here should be interpreted as the value of `x` at all past timesteps.
46+
For example, here `x(k-1)` and `x(k-2)` will be `1.0`, and the inital value of `x(k)` will
47+
thus be `2.0`. During problem construction, the _past_ value of a variable should be
48+
provided. For example, providing `[x => 1.0]` while constructing this problem will error.
49+
Provide `[x(k-1) => 1.0]` instead. Note that values provided during problem construction
50+
_do not_ apply to the entire history. Hence, if `[x(k-1) => 2.0]` is provided, the value of
51+
`x(k-2)` will still be `1.0`.

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ include("systems/diffeqs/first_order_transform.jl")
147147
include("systems/diffeqs/modelingtoolkitize.jl")
148148
include("systems/diffeqs/basic_transformations.jl")
149149

150+
include("systems/discrete_system/discrete_system.jl")
151+
150152
include("systems/jumps/jumpsystem.jl")
151153

152154
include("systems/optimization/constraints_system.jl")
@@ -209,6 +211,7 @@ export ODESystem,
209211
export DAEFunctionExpr, DAEProblemExpr
210212
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
211213
export SystemStructure
214+
export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
212215
export JumpSystem
213216
export ODEProblem, SDEProblem
214217
export NonlinearFunction, NonlinearFunctionExpr

src/clock.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,8 @@ Base.hash(c::SolverStepClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91
146146
function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock)
147147
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t))
148148
end
149+
150+
struct IntegerSequence <: AbstractClock
151+
t::Union{Nothing, Symbolic}
152+
IntegerSequence(t::Union{Num, Symbolic}) = new(value(t))
153+
end

src/discretedomain.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ function (D::Shift)(x::Num, allow_zero = false)
3838
vt = value(x)
3939
if istree(vt)
4040
op = operation(vt)
41-
if op isa Shift
41+
if op isa Sample
42+
error("Cannot shift a `Sample`. Create a variable to represent the sampled value and shift that instead")
43+
elseif op isa Shift
4244
if D.t === nothing || isequal(D.t, op.t)
4345
arg = arguments(vt)[1]
4446
newsteps = D.steps + op.steps
@@ -168,6 +170,7 @@ struct ShiftIndex
168170
steps::Int
169171
ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps)
170172
ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps)
173+
ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps)
171174
end
172175

173176
function (xn::Num)(k::ShiftIndex)

src/structural_transformation/symbolics_tearing.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
382382
dx = fullvars[dv]
383383
# add `x_t`
384384
order, lv = var_order(dv)
385-
x_t = lower_varname(fullvars[lv], iv, order)
386-
push!(fullvars, x_t)
385+
x_t = lower_varname_withshift(fullvars[lv], iv, order)
386+
push!(fullvars, simplify_shifts(x_t))
387387
v_t = length(fullvars)
388388
v_t_idx = add_vertex!(var_to_diff)
389389
add_vertex!(graph, DST)
@@ -437,11 +437,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
437437
# We cannot solve the differential variable like D(x)
438438
if isdervar(iv)
439439
order, lv = var_order(iv)
440-
dx = D(lower_varname(fullvars[lv], idep, order - 1))
441-
eq = dx ~ ModelingToolkit.fixpoint_sub(
440+
dx = D(simplify_shifts(lower_varname_withshift(
441+
fullvars[lv], idep, order - 1)))
442+
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
442443
Symbolics.solve_for(neweqs[ieq],
443444
fullvars[iv]),
444-
total_sub)
445+
total_sub; operator = ModelingToolkit.Shift))
445446
for e in 𝑑neighbors(graph, iv)
446447
e == ieq && continue
447448
for v in 𝑠neighbors(graph, e)
@@ -450,7 +451,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
450451
rem_edge!(graph, e, iv)
451452
end
452453
push!(diff_eqs, eq)
453-
total_sub[eq.lhs] = eq.rhs
454+
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
454455
push!(diffeq_idxs, ieq)
455456
push!(diff_vars, diff_to_var[iv])
456457
continue
@@ -469,7 +470,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
469470
neweq = var ~ ModelingToolkit.fixpoint_sub(
470471
simplify ?
471472
Symbolics.simplify(rhs) : rhs,
472-
total_sub)
473+
total_sub; operator = ModelingToolkit.Shift)
473474
push!(subeqs, neweq)
474475
push!(solved_equations, ieq)
475476
push!(solved_variables, iv)

src/structural_transformation/utils.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,3 +412,40 @@ function numerical_nlsolve(f, u0, p)
412412
# TODO: robust initial guess, better debugging info, and residual check
413413
sol.u
414414
end
415+
416+
###
417+
### Misc
418+
###
419+
420+
function lower_varname_withshift(var, iv, order)
421+
order == 0 && return var
422+
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
423+
op = operation(var)
424+
return Shift(op.t, order)(var)
425+
end
426+
return lower_varname(var, iv, order)
427+
end
428+
429+
function isdoubleshift(var)
430+
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
431+
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
432+
end
433+
434+
function simplify_shifts(var)
435+
ModelingToolkit.hasshift(var) || return var
436+
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
437+
if isdoubleshift(var)
438+
op1 = operation(var)
439+
vv1 = arguments(var)[1]
440+
op2 = operation(vv1)
441+
vv2 = arguments(vv1)[1]
442+
s1 = op1.steps
443+
s2 = op2.steps
444+
t1 = op1.t
445+
t2 = op2.t
446+
return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2))
447+
else
448+
return similarterm(var, operation(var), simplify_shifts.(arguments(var)),
449+
Symbolics.symtype(var); metadata = unwrap(var).metadata)
450+
end
451+
end

src/systems/alias_elimination.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ function observed2graph(eqs, unknowns)
453453
lhs_j === nothing &&
454454
throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns."))
455455
assigns[i] = lhs_j
456-
vs = vars(eq.rhs)
456+
vs = vars(eq.rhs; op = Symbolics.Operator)
457457
for v in vs
458458
j = get(v2j, v, nothing)
459459
j !== nothing && add_edge!(graph, i, j)
@@ -463,11 +463,11 @@ function observed2graph(eqs, unknowns)
463463
return graph, assigns
464464
end
465465

466-
function fixpoint_sub(x, dict)
467-
y = fast_substitute(x, dict)
466+
function fixpoint_sub(x, dict; operator = Nothing)
467+
y = fast_substitute(x, dict; operator)
468468
while !isequal(x, y)
469469
y = x
470-
x = fast_substitute(y, dict)
470+
x = fast_substitute(y, dict; operator)
471471
end
472472

473473
return x

src/systems/clock_inference.jl

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ function split_system(ci::ClockInference{S}) where {S}
133133
tss = similar(cid_to_eq, S)
134134
for (id, ieqs) in enumerate(cid_to_eq)
135135
ts_i = system_subset(ts, ieqs)
136-
@set! ts_i.structure.only_discrete = id != continuous_id
136+
if id != continuous_id
137+
ts_i = shift_discrete_system(ts_i)
138+
@set! ts_i.structure.only_discrete = true
139+
end
137140
tss[id] = ts_i
138141
end
139142
return tss, inputs, continuous_id, id_to_clock
@@ -148,7 +151,7 @@ function generate_discrete_affect(
148151
end
149152
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
150153
out = Sym{Any}(:out)
151-
appended_parameters = parameters(syss[continuous_id])
154+
appended_parameters = full_parameters(syss[continuous_id])
152155
offset = length(appended_parameters)
153156
param_to_idx = if use_index_cache
154157
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
@@ -180,40 +183,46 @@ function generate_discrete_affect(
180183
disc_to_cont_idxs = Int[]
181184
end
182185
for v in inputs[continuous_id]
183-
vv = arguments(v)[1]
184-
if vv in fullvars
185-
push!(needed_disc_to_cont_obs, vv)
186+
_v = arguments(v)[1]
187+
if _v in fullvars
188+
push!(needed_disc_to_cont_obs, _v)
189+
push!(disc_to_cont_idxs, param_to_idx[v])
190+
continue
191+
end
192+
193+
# If the held quantity is calculated through observed
194+
# it will be shifted forward by 1
195+
_v = Shift(get_iv(sys), 1)(_v)
196+
if _v in fullvars
197+
push!(needed_disc_to_cont_obs, _v)
186198
push!(disc_to_cont_idxs, param_to_idx[v])
199+
continue
187200
end
188201
end
189-
append!(appended_parameters, input, unknowns(sys))
202+
append!(appended_parameters, input)
190203
cont_to_disc_obs = build_explicit_observed_function(
191204
use_index_cache ? osys : syss[continuous_id],
192205
needed_cont_to_disc_obs,
193206
throw = false,
194207
expression = true,
195208
output_type = SVector)
196-
@set! sys.ps = appended_parameters
197209
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
198210
throw = false,
199211
expression = true,
200212
output_type = SVector,
201-
ps = reorder_parameters(osys, full_parameters(sys)))
213+
op = Shift,
214+
ps = reorder_parameters(osys, appended_parameters))
202215
ni = length(input)
203216
ns = length(unknowns(sys))
204217
disc = Func(
205218
[
206219
out,
207220
DestructuredArgs(unknowns(osys)),
208-
if use_index_cache
209-
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))
210-
else
211-
(DestructuredArgs(appended_parameters),)
212-
end...,
221+
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
213222
get_iv(sys)
214223
],
215224
[],
216-
let_block)
225+
let_block) |> toexpr
217226
if use_index_cache
218227
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
219228
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
@@ -235,8 +244,14 @@ function generate_discrete_affect(
235244
end
236245
empty_disc = isempty(disc_range)
237246
disc_init = if use_index_cache
238-
:(function (p, t)
247+
:(function (u, p, t)
248+
c2d_obs = $cont_to_disc_obs
239249
d2c_obs = $disc_to_cont_obs
250+
result = c2d_obs(u, p..., t)
251+
for (val, i) in zip(result, $cont_to_disc_idxs)
252+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
253+
end
254+
240255
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
241256
result = d2c_obs(disc_state, p..., t)
242257
for (val, i) in zip(result, $disc_to_cont_idxs)
@@ -248,11 +263,14 @@ function generate_discrete_affect(
248263
repack(discretes) # to force recalculation of dependents
249264
end)
250265
else
251-
:(function (p, t)
266+
:(function (u, p, t)
267+
c2d_obs = $cont_to_disc_obs
252268
d2c_obs = $disc_to_cont_obs
269+
c2d_view = view(p, $cont_to_disc_idxs)
253270
d2c_view = view(p, $disc_to_cont_idxs)
254-
disc_state = view(p, $disc_range)
255-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
271+
disc_unknowns = view(p, $disc_range)
272+
copyto!(c2d_view, c2d_obs(u, p, t))
273+
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
256274
end)
257275
end
258276

@@ -277,9 +295,6 @@ function generate_discrete_affect(
277295
# TODO: find a way to do this without allocating
278296
disc = $disc
279297

280-
push!(saved_values.t, t)
281-
push!(saved_values.saveval, $save_vec)
282-
283298
# Write continuous into to discrete: handles `Sample`
284299
# Write discrete into to continuous
285300
# Update discrete unknowns
@@ -329,6 +344,10 @@ function generate_discrete_affect(
329344
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
330345
end
331346
)
347+
348+
push!(saved_values.t, t)
349+
push!(saved_values.saveval, $save_vec)
350+
332351
# @show "after d2c", p
333352
$(
334353
if use_index_cache

0 commit comments

Comments
 (0)