Skip to content

Commit 2a97325

Browse files
committed
feat: initialization of DiscreteSystem
1 parent 406d0a8 commit 2a97325

File tree

5 files changed

+193
-68
lines changed

5 files changed

+193
-68
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 85 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,18 @@ in order to properly generate the difference equations.
341341
342342
In the system x(k) ~ x(k-1) + x(k-2), becomes Shift(t, 1)(x(t)) ~ x(t) + Shift(t, -1)(x(t))
343343
344-
The lowest-order term is Shift(t, k)(x(t)), instead of x(t).
345-
As such we actually want dummy variables for the k-1 lowest order terms instead of the k-1 highest order terms.
344+
The lowest-order term is Shift(t, k)(x(t)), instead of x(t). As such we actually want
345+
dummy variables for the k-1 lowest order terms instead of the k-1 highest order terms.
346346
347347
Shift(t, -1)(x(t)) -> x\_{t-1}(t)
348348
349-
Since Shift(t, -1)(x) is not a derivative, it is directly substituted in `fullvars`. No equation or variable is added for it.
349+
Since Shift(t, -1)(x) is not a derivative, it is directly substituted in `fullvars`.
350+
No equation or variable is added for it.
350351
351-
For ODESystems D(D(D(x))) in equations is recursively substituted as D(x) ~ x_t, D(x_t) ~ x_tt, etc. The analogue for discrete systems, Shift(t, 1)(Shift(t,1)(Shift(t,1)(Shift(t, -3)(x(t))))) does not actually appear. So `total_sub` in generate_system_equations` is directly initialized with all of the lowered variables `Shift(t, -3)(x) -> x_t-3(t)`, etc.
352+
For ODESystems D(D(D(x))) in equations is recursively substituted as D(x) ~ x_t, D(x_t) ~ x_tt, etc.
353+
The analogue for discrete systems, Shift(t, 1)(Shift(t,1)(Shift(t,1)(Shift(t, -3)(x(t)))))
354+
does not actually appear. So `total_sub` in generate_system_equations` is directly
355+
initialized with all of the lowered variables `Shift(t, -3)(x) -> x_t-3(t)`, etc.
352356
=#
353357
"""
354358
Generate new derivative variables for the system.
@@ -358,7 +362,7 @@ Effects on the system structure:
358362
- neweqs: add the identity equations for the new variables, D(x) ~ x_t
359363
- graph: update graph with the new equations and variables, and their connections
360364
- solvable_graph:
361-
- var_eq_matching: match D(x) to the added identity equation
365+
- var_eq_matching: match D(x) to the added identity equation D(x) ~ x_t
362366
"""
363367
function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
364368
@unpack fullvars, sys, structure = ts
@@ -406,7 +410,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
406410
end
407411

408412
"""
409-
Check if there's `D(x) = x_t` already.
413+
Check if there's `D(x) ~ x_t` already.
410414
"""
411415
function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
412416
for eq in 𝑑neighbors(solvable_graph, dv)
@@ -427,6 +431,10 @@ function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
427431
return nothing
428432
end
429433

434+
"""
435+
Add a dummy derivative variable x_t corresponding to symbolic variable D(x)
436+
which has index dv in `fullvars`. Return the new index of x_t.
437+
"""
430438
function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv)
431439
push!(fullvars, simplify_shifts(x_t))
432440
v_t = length(fullvars)
@@ -439,7 +447,11 @@ function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv)
439447
v_t
440448
end
441449

442-
# dv = index of D(x), v_t = index of x_t
450+
"""
451+
Add the equation D(x) - x_t ~ 0 to `neweqs`. `dv` and `v_t` are the indices
452+
of the higher-order derivative variable and the newly-introduced dummy
453+
derivative variable. Return the index of the new equation in `neweqs`.
454+
"""
443455
function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t)
444456
push!(neweqs, eq)
445457
add_vertex!(s.graph, SRC)
@@ -452,8 +464,33 @@ function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t)
452464
end
453465

454466
"""
455-
Solve the solvable equations of the system and generate differential (or discrete)
456-
equations in terms of the selected states.
467+
Solve the equations in `neweqs` to obtain the final equations of the
468+
system.
469+
470+
For each equation of `neweqs`, do one of the following:
471+
1. If the equation is solvable for a differentiated variable D(x),
472+
then solve for D(x), and add D(x) ~ sol as a differential equation
473+
of the system.
474+
2. If the equation is solvable for an un-differentiated variable x,
475+
solve for x and then add x ~ sol as a solved equation. These will
476+
become observables.
477+
3. If the equation is not solvable, add it as an algebraic equation.
478+
479+
Solved equations are added to `total_sub`. Occurrences of differential
480+
or solved variables on the RHS of the final equations will get substituted.
481+
The topological sort of the equations ensures that variables are solved for
482+
before they appear in equations.
483+
484+
Reorder the equations and unknowns to be:
485+
[diffeqs; ...]
486+
[diffvars; ...]
487+
such that the mass matrix is:
488+
[I 0
489+
0 0].
490+
491+
Order the new equations and variables such that the differential equations
492+
and variables come first. Return the new equations, the solved equations,
493+
the new orderings, and the number of solved variables and equations.
457494
"""
458495
function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, iv = nothing, D = nothing)
459496
@unpack fullvars, sys, structure = state
@@ -550,6 +587,9 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
550587
return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
551588
end
552589

590+
"""
591+
Occurs when a variable D(x) occurs in a non-differential system.
592+
"""
553593
struct UnexpectedDifferentialError
554594
eq::Equation
555595
end
@@ -558,12 +598,20 @@ function Base.showerror(io::IO, err::UnexpectedDifferentialError)
558598
error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(err.eq)")
559599
end
560600

601+
"""
602+
Generate a first-order differential equation whose LHS is `dx`.
603+
604+
`var` and `dx` represent the same variable, but `var` may be a higher-order differential and `dx` is always first-order. For example, if `var` is D(D(x)), then `dx` would be `D(x_t)`. Solve `eq` for `var`, substitute previously solved variables, and return the differential equation.
605+
"""
561606
function make_differential_equation(var, dx, eq, total_sub)
562607
dx ~ simplify_shifts(Symbolics.fixpoint_sub(
563608
Symbolics.symbolic_linear_solve(eq, var),
564609
total_sub; operator = ModelingToolkit.Shift))
565610
end
566611

612+
"""
613+
Generate an algebraic equation. Substitute solved variables into `eq` and return the equation.
614+
"""
567615
function make_algebraic_equation(eq, total_sub)
568616
rhs = eq.rhs
569617
if !(eq.lhs isa Number && eq.lhs == 0)
@@ -572,6 +620,9 @@ function make_algebraic_equation(eq, total_sub)
572620
0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub))
573621
end
574622

623+
"""
624+
Solve equation `eq` for `var`, substitute previously solved variables, and return the solved equation.
625+
"""
575626
function make_solved_equation(var, eq, total_sub; simplify = false)
576627
residual = eq.lhs - eq.rhs
577628
a, b, islinear = linear_expansion(residual, var)
@@ -591,17 +642,13 @@ function make_solved_equation(var, eq, total_sub; simplify = false)
591642
end
592643

593644
"""
594-
Reorder the equations and unknowns to be:
595-
[diffeqs; ...]
596-
[diffvars; ...]
597-
such that the mass matrix is:
598-
[I 0
599-
0 0].
600-
601-
Update the state to account for the new ordering and equations.
645+
Given the ordering returned by `generate_system_equations!`, update the
646+
tearing state to account for the new order. Permute the variables and equations.
647+
Eliminate the solved variables and equations from the graph and permute the
648+
graph's vertices to account for the new variable/equation ordering.
602649
"""
603650
# TODO: BLT sorting
604-
function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
651+
function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nsolved_eq, nsolved_var)
605652
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
606653

607654
eqsperm = zeros(Int, nsrcs(graph))
@@ -616,7 +663,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_or
616663
# Contract the vertices in the structure graph to make the structure match
617664
# the new reality of the system we've just created.
618665
new_graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
619-
nelim_eq, nelim_var)
666+
nsolved_eq, nsolved_var)
620667

621668
new_var_to_diff = complete(DiffGraph(length(var_ordering)))
622669
for (v, d) in enumerate(var_to_diff)
@@ -643,7 +690,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_or
643690
end
644691

645692
"""
646-
Set the system equations, unknowns, observables post-tearing.
693+
Update the system equations, unknowns, and observables after simplification.
647694
"""
648695
function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
649696
cse_hack = true, array_hack = true)
@@ -685,16 +732,10 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
685732
sys = schedule(sys)
686733
end
687734

688-
# Terminology and Definition:
689-
# A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
690-
# characterize variables in `u(t)` into two classes: differential variables
691-
# (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
692-
# variables are marked as `SelectedState` and they are differentiated in the
693-
# DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
694-
# appear in the system. Algebraic variables are variables that are not
695-
# differential variables.
696735

697-
# Give the order of the variable indexed by dv
736+
"""
737+
Give the order of the variable indexed by dv.
738+
"""
698739
function var_order(dv, diff_to_var)
699740
order = 0
700741
while (dv′ = diff_to_var[dv]) !== nothing
@@ -704,6 +745,21 @@ function var_order(dv, diff_to_var)
704745
order, dv
705746
end
706747

748+
"""
749+
Main internal function for structural simplification for DAE systems and discrete systems.
750+
Generate dummy derivative variables, new equations in terms of variables, return updated
751+
system and tearing state.
752+
753+
Terminology and Definition:
754+
755+
A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can
756+
characterize variables in `u(t)` into two classes: differential variables
757+
(denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential
758+
variables are marked as `SelectedState` and they are differentiated in the
759+
DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually
760+
appear in the system. Algebraic variables are variables that are not
761+
differential variables.
762+
"""
707763
function tearing_reassemble(state::TearingState, var_eq_matching,
708764
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
709765
extra_vars = Int[]

src/structural_transformation/utils.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,12 @@ end
449449
### Misc
450450
###
451451

452-
# For discrete variables. Turn Shift(t, k)(x(t)) into xₜ₋ₖ(t)
452+
"""
453+
Handle renaming variable names for discrete structural simplification. Three cases:
454+
- positive shift: do nothing
455+
- zero shift: x(t) => Shift(t, 0)(x(t))
456+
- negative shift: rename the variable
457+
"""
453458
function lower_shift_varname(var, iv)
454459
op = operation(var)
455460
op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t)
@@ -460,30 +465,46 @@ function lower_shift_varname(var, iv)
460465
end
461466
end
462467

463-
function shift2term(var)
468+
"""
469+
Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
470+
"""
471+
function shift2term(var)
464472
backshift = operation(var).steps
465473
iv = operation(var).t
466-
num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift)))
467-
ds = join([Char(0x209c), Char(0x208b), num])
468-
#ds = "$iv-$(-backshift)"
469-
#d_separator = 'ˍ'
470-
471-
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
472-
O = only(arguments(var))
473-
oldop = operation(O)
474-
newname = Symbol(string(nameof(oldop)), ds)
475-
else
476-
O = var
477-
oldop = operation(var)
478-
varname = split(string(nameof(oldop)), d_separator)[1]
479-
newname = Symbol(varname, d_separator, ds)
480-
end
474+
num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
475+
ds = join([Char(0x209c), Char(0x208b), num])
476+
# Char(0x209c) = ₜ
477+
# Char(0x208b) = ₋ (subscripted minus)
478+
479+
O = only(arguments(var))
480+
oldop = operation(O)
481+
newname = Symbol(string(nameof(oldop)), ds)
482+
481483
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
482484
newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
483485
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
486+
newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift)
484487
return newvar
485488
end
486489

490+
function term2shift(var)
491+
var = Symbolics.unwrap(var)
492+
name = Symbolics.getname(var)
493+
O = only(arguments(var))
494+
oldop = operation(O)
495+
iv = only(arguments(x))
496+
# Split on ₋
497+
if occursin(Char(0x208b), name)
498+
substrings = split(name, Char(0x208b))
499+
shift = last(split(name, Char(0x208b)))
500+
newname = join(substrings[1:end-1])[1:end-1]
501+
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
502+
return Shift(iv, -shift)(newvar)
503+
else
504+
return var
505+
end
506+
end
507+
487508
function isdoubleshift(var)
488509
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
489510
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)

src/systems/discrete_system/discrete_system.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,19 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
270270
v = u0map[k]
271271
if !((op = operation(k)) isa Shift)
272272
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
273+
elseif op.steps > 0
274+
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")
273275
end
276+
274277
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
275278
end
276279
for var in unknowns(sys)
277280
op = operation(var)
278-
haskey(updated, var) && continue
279281
root = getunshifted(var)
282+
shift = getshift(var)
280283
isnothing(root) && continue
281-
haskey(defs, root) || error("Initial condition for $root not provided.")
284+
(haskey(updated, Shift(iv, shift)(root)) || haskey(updated, var)) && continue
285+
haskey(defs, root) || error("Initial condition for $var not provided.")
282286
updated[var] = defs[root]
283287
end
284288
return updated

src/variables.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ struct VariableOutput end
66
struct VariableIrreducible end
77
struct VariableStatePriority end
88
struct VariableMisc end
9+
# Metadata for renamed shift variables xₜ₋₁
910
struct VariableUnshifted end
11+
struct VariableShift end
1012
Symbolics.option_to_metadata_type(::Val{:unit}) = VariableUnit
1113
Symbolics.option_to_metadata_type(::Val{:connect}) = VariableConnectType
1214
Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput
@@ -15,6 +17,7 @@ Symbolics.option_to_metadata_type(::Val{:irreducible}) = VariableIrreducible
1517
Symbolics.option_to_metadata_type(::Val{:state_priority}) = VariableStatePriority
1618
Symbolics.option_to_metadata_type(::Val{:misc}) = VariableMisc
1719
Symbolics.option_to_metadata_type(::Val{:unshifted}) = VariableUnshifted
20+
Symbolics.option_to_metadata_type(::Val{:shift}) = VariableShift
1821

1922
"""
2023
dump_variable_metadata(var)
@@ -97,7 +100,7 @@ struct Stream <: AbstractConnectType end # special stream connector
97100
98101
Get the connect type of x. See also [`hasconnect`](@ref).
99102
"""
100-
getconnect(x) = getconnect(unwrap(x))
103+
getconnect(x::Num) = getconnect(unwrap(x))
101104
getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing)
102105
"""
103106
hasconnect(x)
@@ -264,7 +267,7 @@ end
264267
end
265268

266269
struct IsHistory end
267-
ishistory(x) = ishistory(unwrap(x))
270+
ishistory(x::Num) = ishistory(unwrap(x))
268271
ishistory(x::Symbolic) = getmetadata(x, IsHistory, false)
269272
hist(x, t) = wrap(hist(unwrap(x), t))
270273
function hist(x::Symbolic, t)
@@ -575,7 +578,7 @@ end
575578
Fetch any miscellaneous data associated with symbolic variable `x`.
576579
See also [`hasmisc(x)`](@ref).
577580
"""
578-
getmisc(x) = getmisc(unwrap(x))
581+
getmisc(x::Num) = getmisc(unwrap(x))
579582
getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing)
580583
"""
581584
hasmisc(x)
@@ -594,7 +597,7 @@ setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata)
594597
595598
Fetch the unit associated with variable `x`. This function is a metadata getter for an individual variable, while `get_unit` is used for unit inference on more complicated sdymbolic expressions.
596599
"""
597-
getunit(x) = getunit(unwrap(x))
600+
getunit(x::Num) = getunit(unwrap(x))
598601
getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing)
599602
"""
600603
hasunit(x)
@@ -603,5 +606,8 @@ Check if the variable `x` has a unit.
603606
"""
604607
hasunit(x) = getunit(x) !== nothing
605608

606-
getunshifted(x) = getunshifted(unwrap(x))
609+
getunshifted(x::Num) = getunshifted(unwrap(x))
607610
getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)
611+
612+
getshift(x::Num) = getshift(unwrap(x))
613+
getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0)

0 commit comments

Comments
 (0)