Skip to content

Commit 460252a

Browse files
YingboMabaggepinnen
andcommitted
Add special handling for discrete only systems
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent 522fbb9 commit 460252a

File tree

6 files changed

+53
-30
lines changed

6 files changed

+53
-30
lines changed

src/structural_transformation/pantelides.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing}
131131
pathfound = construct_augmenting_path!(var_eq_matching, graph, eq′,
132132
v -> varwhitelist[v], vcolor, ecolor)
133133
pathfound && break # terminating condition
134+
if is_only_discrete(state.structure)
135+
error("The discrete system has high structural index. This is not supported.")
136+
end
134137
for var in eachindex(vcolor)
135138
vcolor[var] || continue
136139
if var_to_diff[var] === nothing

src/structural_transformation/partial_state_selection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing,
154154
(ag, diff_va) = (nothing, nothing);
155155
state_priority = nothing, kwargs...)
156156
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
157-
var_eq_matching = complete(pantelides!(state, ag))
158157
complete!(state.structure)
158+
var_eq_matching = complete(pantelides!(state, ag))
159159
dummy_derivative_graph!(state.structure, var_eq_matching, jac, (ag, diff_va),
160160
state_priority)
161161
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function full_equations(sys::AbstractSystem; simplify = false)
9292
@unpack subs = substitutions
9393
solved = Dict(eq.lhs => eq.rhs for eq in subs)
9494
neweqs = map(equations(sys)) do eq
95-
if isdiffeq(eq)
95+
if istree(eq.lhs) && operation(eq.lhs) isa Union{Shift, Differential}
9696
return tearing_sub(eq.lhs, solved, simplify) ~ tearing_sub(eq.rhs, solved,
9797
simplify)
9898
else
@@ -262,7 +262,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
262262

263263
if ModelingToolkit.has_iv(state.sys)
264264
iv = get_iv(state.sys)
265-
D = Differential(iv)
265+
if is_only_discrete(state.structure)
266+
D = Shift(iv, 1)
267+
else
268+
D = Differential(iv)
269+
end
266270
else
267271
iv = D = nothing
268272
end
@@ -631,7 +635,9 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
631635

632636
sys = state.sys
633637
@set! sys.eqs = neweqs
634-
@set! sys.states = [v for (i, v) in enumerate(fullvars) if diff_to_var[i] === nothing]
638+
@set! sys.states = Any[v
639+
for (i, v) in enumerate(fullvars)
640+
if diff_to_var[i] === nothing && !isempty(𝑑neighbors(graph, i))]
635641
removed_obs_set = BitSet(removed_obs)
636642
var_to_idx = Dict(reverse(en) for en in enumerate(fullvars))
637643
# Make sure differentiated variables don't appear in observed equations

src/systems/clock_inference.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
function split_system(ci::ClockInference)
8282
@unpack ts, eq_domain, var_domain, inferred = ci
8383
@unpack fullvars = ts
84-
@unpack graph = ts.structure
84+
@unpack graph, var_to_diff = ts.structure
8585
continuous_id = Ref(0)
8686
clock_to_id = Dict{TimeDomain, Int}()
8787
id_to_clock = TimeDomain[]
@@ -106,11 +106,14 @@ function split_system(ci::ClockInference)
106106
eq_to_cid[i] = cid
107107
resize_or_push!(cid_to_eq, i, cid)
108108
end
109+
continuous_id = continuous_id[]
109110
input_idxs = map(_ -> Int[], 1:cid_counter[])
110111
inputs = map(_ -> Any[], 1:cid_counter[])
111-
for (i, d) in enumerate(var_domain)
112+
nvv = length(var_domain)
113+
for i in 1:nvv
114+
d = var_domain[i]
112115
cid = get(clock_to_id, d, 0)
113-
@assert cid!==0 "Internal error!"
116+
@assert cid!==0 "Internal error! Variable $(fullvars[i]) doesn't have a inferred time domain."
114117
var_to_cid[i] = cid
115118
v = fullvars[i]
116119
#TODO: remove Inferred*
@@ -139,10 +142,11 @@ function split_system(ci::ClockInference)
139142
eq_to_diff[j] = ts_i.structure.eq_to_diff[eq_i]
140143
end
141144
@set! ts_i.structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
145+
@set! ts_i.structure.only_discrete = id != continuous_id
142146
@set! ts_i.sys.eqs = eqs_i
143147
@set! ts_i.structure.eq_to_diff = eq_to_diff
144148
tss[id] = ts_i
145-
# TODO: just mark past and sample variables as inputs
149+
# TODO: just mark current and sample variables as inputs
146150
end
147151
return tss, inputs
148152

src/systems/systemstructure.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727

2828
export SystemStructure, TransformationState, TearingState, structural_simplify!
2929
export initialize_system_structure, find_linear_equations
30-
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs
30+
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete
3131
export dervars_range, diffvars_range, algvars_range
3232
export DiffGraph, complete!
3333

@@ -143,7 +143,9 @@ Base.@kwdef mutable struct SystemStructure
143143
# or as `torn` to assert that tearing has run.
144144
graph::BipartiteGraph{Int, Nothing}
145145
solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
146+
only_discrete::Bool
146147
end
148+
is_only_discrete(s::SystemStructure) = s.only_discrete
147149
isdervar(s::SystemStructure, i) = invview(s.var_to_diff)[i] !== nothing
148150
function isalgvar(s::SystemStructure, i)
149151
s.var_to_diff[i] === nothing &&
@@ -277,9 +279,6 @@ function TearingState(sys; quick_cancel = false, check = true)
277279
end
278280
end
279281
idx = addvar!(v)
280-
#if !(idx in dervaridxs)
281-
# push!(dervaridxs, idx)
282-
#end
283282
end
284283

285284
if istree(var) && operation(var) isa Symbolics.Operator &&
@@ -305,7 +304,7 @@ function TearingState(sys; quick_cancel = false, check = true)
305304
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
306305
for dervaridx in dervaridxs
307306
dervar = fullvars[dervaridx]
308-
diffvar = arguments(dervar)[1]
307+
diffvar = lower_order_var(dervar)
309308
if !(diffvar in sorted_fullvars)
310309
push!(sorted_fullvars, diffvar)
311310
end
@@ -324,24 +323,12 @@ function TearingState(sys; quick_cancel = false, check = true)
324323
var_to_diff = DiffGraph(nvars, true)
325324
for dervaridx in dervaridxs
326325
dervar = fullvars[dervaridx]
327-
diffvar = arguments(dervar)[1]
326+
diffvar = lower_order_var(dervar)
328327
diffvaridx = var2idx[diffvar]
329328
push!(diffvars, diffvar)
330329
var_to_diff[diffvaridx] = dervaridx
331330
end
332331

333-
#=
334-
algvars = setdiff(states(sys), diffvars)
335-
for algvar in algvars
336-
# it could be that a variable appeared in the states, but never appeared
337-
# in the equations.
338-
algvaridx = get(var2idx, algvar, 0)
339-
#if algvaridx == 0
340-
# check ? throw(InvalidSystemException("The system is missing an equation for $algvar.")) : return nothing
341-
#end
342-
end
343-
=#
344-
345332
graph = BipartiteGraph(neqs, nvars, Val(false))
346333
for (ie, vars) in enumerate(symbolic_incidence), v in vars
347334
jv = var2idx[v]
@@ -354,7 +341,23 @@ function TearingState(sys; quick_cancel = false, check = true)
354341

355342
return TearingState(sys, fullvars,
356343
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
357-
complete(graph), nothing), Any[])
344+
complete(graph), nothing, false), Any[])
345+
end
346+
347+
function lower_order_var(dervar)
348+
if isdifferential(dervar)
349+
diffvar = arguments(dervar)[1]
350+
else # shift
351+
s = operation(dervar)
352+
step = s.steps - 1
353+
vv = arguments(dervar)[1]
354+
if step >= 1
355+
diffvar = Shift(s.t, step)(vv)
356+
else
357+
diffvar = vv
358+
end
359+
end
360+
diffvar
358361
end
359362

360363
using .BipartiteGraphs: Label, BipartiteAdjacencyList

test/clock.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ eqmap = ci.eq_domain
6666
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
6767
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
6868
@test equations(sss) == [D(x) ~ u - x]
69-
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
69+
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()),
70+
check_consistency = false)
7071
@test isempty(equations(sss))
7172
@test observed(sss) == [r ~ 1.0; yd ~ Sample(t, dt)(y); ud ~ kp * (r - yd)]
7273

@@ -104,15 +105,21 @@ eqs = [yd ~ Sample(t, dt)(y)
104105
y ~ x
105106
z(k + 2) ~ z(k) + yd
106107
#=
107-
z(k + 2) ~ z(k)
108+
z(k + 2) ~ z(k) + yd
108109
=>
109-
z′(k + 1) ~ z(k)
110+
z′(k + 1) ~ z(k) + yd
110111
z(k + 1) ~ z′(k)
111112
=#
112113
]
113114
@named sys = ODESystem(eqs)
114115
ci, varmap = infer_clocks(sys)
115116
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
117+
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()),
118+
check_consistency = false)
119+
@test length(states(sss)) == 2
120+
z, z_t = states(sss)
121+
S = Shift(t, 1)
122+
@test full_equations(sss) == [S(z) ~ z_t; S(z_t) ~ z + Sample(t, dt)(y)]
116123

117124
@info "Testing multi-rate hybrid system"
118125
dt = 0.1

0 commit comments

Comments
 (0)