Skip to content

Commit 305a8bd

Browse files
authored
Merge pull request #1913 from SciML/myb_fb/simplify_clock
Add `split_system` that splits the system by their time domain
2 parents bbe5081 + b361973 commit 305a8bd

File tree

9 files changed

+267
-47
lines changed

9 files changed

+267
-47
lines changed

src/inputoutput.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,13 @@ function inputs_to_parameters!(state::TransformationState, io)
302302
ps = parameters(sys)
303303

304304
if io !== nothing
305+
inputs, = io
305306
# Change order of new parameters to correspond to user-provided order in argument `inputs`
306307
d = Dict{Any, Int}()
307308
for (i, inp) in enumerate(new_parameters)
308309
d[inp] = i
309310
end
310-
permutation = [d[i] for i in io.inputs]
311+
permutation = [d[i] for i in inputs]
311312
new_parameters = new_parameters[permutation]
312313
end
313314

src/structural_transformation/pantelides.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing}
131131
pathfound = construct_augmenting_path!(var_eq_matching, graph, eq′,
132132
v -> varwhitelist[v], vcolor, ecolor)
133133
pathfound && break # terminating condition
134+
if is_only_discrete(state.structure)
135+
error("The discrete system has high structural index. This is not supported.")
136+
end
134137
for var in eachindex(vcolor)
135138
vcolor[var] || continue
136139
if var_to_diff[var] === nothing

src/structural_transformation/partial_state_selection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing,
154154
(ag, diff_va) = (nothing, nothing);
155155
state_priority = nothing, kwargs...)
156156
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
157-
var_eq_matching = complete(pantelides!(state, ag))
158157
complete!(state.structure)
158+
var_eq_matching = complete(pantelides!(state, ag))
159159
dummy_derivative_graph!(state.structure, var_eq_matching, jac, (ag, diff_va),
160160
state_priority)
161161
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function full_equations(sys::AbstractSystem; simplify = false)
9292
@unpack subs = substitutions
9393
solved = Dict(eq.lhs => eq.rhs for eq in subs)
9494
neweqs = map(equations(sys)) do eq
95-
if isdiffeq(eq)
95+
if istree(eq.lhs) && operation(eq.lhs) isa Union{Shift, Differential}
9696
return tearing_sub(eq.lhs, solved, simplify) ~ tearing_sub(eq.rhs, solved,
9797
simplify)
9898
else
@@ -262,7 +262,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
262262

263263
if ModelingToolkit.has_iv(state.sys)
264264
iv = get_iv(state.sys)
265-
D = Differential(iv)
265+
if is_only_discrete(state.structure)
266+
D = Shift(iv, 1)
267+
else
268+
D = Differential(iv)
269+
end
266270
else
267271
iv = D = nothing
268272
end
@@ -628,10 +632,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
628632
@set! state.structure.var_to_diff = var_to_diff
629633
@set! state.structure.eq_to_diff = eq_to_diff
630634
@set! state.fullvars = fullvars = fullvars[invvarsperm]
635+
ispresent = let var_to_diff = var_to_diff, graph = graph
636+
i -> (!isempty(𝑑neighbors(graph, i)) ||
637+
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
638+
end
631639

632640
sys = state.sys
633641
@set! sys.eqs = neweqs
634-
@set! sys.states = [v for (i, v) in enumerate(fullvars) if diff_to_var[i] === nothing]
642+
@set! sys.states = Any[v
643+
for (i, v) in enumerate(fullvars)
644+
if diff_to_var[i] === nothing && ispresent(i)]
635645
removed_obs_set = BitSet(removed_obs)
636646
var_to_idx = Dict(reverse(en) for en in enumerate(fullvars))
637647
# Make sure differentiated variables don't appear in observed equations

src/structural_transformation/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ function check_consistency(state::TearingState, ag = nothing)
7676

7777
unassigned_var = []
7878
for (vj, eq) in enumerate(extended_var_eq_matching)
79-
if eq === unassigned && (ag === nothing || !haskey(ag, vj))
79+
if eq === unassigned && (ag === nothing || !haskey(ag, vj)) &&
80+
!isempty(𝑑neighbors(graph, vj))
8081
push!(unassigned_var, fullvars[vj])
8182
end
8283
end

src/systems/abstractsystem.jl

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,22 +1031,11 @@ This will convert all `inputs` to parameters and allow them to be unconnected, i
10311031
simplification will allow models where `n_states = n_equations - n_inputs`.
10321032
"""
10331033
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
1034-
simplify_constants = true, check_consistency = true, kwargs...)
1034+
kwargs...)
10351035
sys = expand_connections(sys)
10361036
sys isa DiscreteSystem && return sys
10371037
state = TearingState(sys)
1038-
has_io = io !== nothing
1039-
has_io && markio!(state, io...)
1040-
state, input_idxs = inputs_to_parameters!(state, io)
1041-
sys, ag = alias_elimination!(state; kwargs...)
1042-
if check_consistency
1043-
ModelingToolkit.check_consistency(state, ag)
1044-
end
1045-
sys = dummy_derivative(sys, state, ag; simplify)
1046-
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
1047-
@set! sys.observed = topsort_equations(observed(sys), fullstates)
1048-
invalidate_cache!(sys)
1049-
return has_io ? (sys, input_idxs) : sys
1038+
structural_simplify!(state, io; simplify, kwargs...)
10501039
end
10511040

10521041
function eliminate_constants(sys::AbstractSystem)
@@ -1063,7 +1052,7 @@ end
10631052

10641053
function io_preprocessing(sys::AbstractSystem, inputs,
10651054
outputs; simplify = false, kwargs...)
1066-
sys, input_idxs = structural_simplify(sys, (; inputs, outputs); simplify, kwargs...)
1055+
sys, input_idxs = structural_simplify(sys, (inputs, outputs); simplify, kwargs...)
10671056

10681057
eqs = equations(sys)
10691058
alg_start_idx = findfirst(!isdiffeq, eqs)
@@ -1150,8 +1139,8 @@ end
11501139

11511140
function markio!(state, inputs, outputs; check = true)
11521141
fullvars = state.fullvars
1153-
inputset = Dict(inputs .=> false)
1154-
outputset = Dict(outputs .=> false)
1142+
inputset = Dict{Any, Bool}(i => false for i in inputs)
1143+
outputset = Dict{Any, Bool}(o => false for o in outputs)
11551144
for (i, v) in enumerate(fullvars)
11561145
if v in keys(inputset)
11571146
v = setio(v, true, false)
@@ -1166,9 +1155,13 @@ function markio!(state, inputs, outputs; check = true)
11661155
fullvars[i] = v
11671156
end
11681157
end
1169-
check && (all(values(inputset)) ||
1170-
error("Some specified inputs were not found in system. The following Dict indicates the found variables ",
1171-
inputset))
1158+
if check
1159+
ikeys = keys(filter(!last, inputset))
1160+
if !isempty(ikeys)
1161+
error("Some specified inputs were not found in system. The following variables were not found ",
1162+
ikeys)
1163+
end
1164+
end
11721165
check && (all(values(outputset)) ||
11731166
error("Some specified outputs were not found in system. The following Dict indicates the found variables ",
11741167
outputset))

src/systems/clock_inference.jl

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function infer_clocks!(ci::ClockInference)
2828
@unpack ts, eq_domain, var_domain, inferred = ci
2929
@unpack fullvars = ts
3030
@unpack graph = ts.structure
31-
# TODO: add a graph time to do this lazily
31+
# TODO: add a graph type to do this lazily
3232
var_graph = SimpleGraph(ndsts(graph))
3333
for eq in 𝑠vertices(graph)
3434
vvs = 𝑠neighbors(graph, eq)
@@ -58,9 +58,97 @@ function infer_clocks!(ci::ClockInference)
5858
vd = var_domain[v]
5959
eqs = 𝑑neighbors(graph, v)
6060
isempty(eqs) && continue
61-
eq = first(eqs)
62-
eq_domain[eq] = vd
61+
#eq = first(eqs)
62+
for eq in eqs
63+
eq_domain[eq] = vd
64+
end
6365
end
6466

6567
return ci
6668
end
69+
70+
function resize_or_push!(v, val, idx)
71+
n = length(v)
72+
if idx > n
73+
for i in (n + 1):idx
74+
push!(v, Int[])
75+
end
76+
resize!(v, idx)
77+
end
78+
push!(v[idx], val)
79+
end
80+
81+
function split_system(ci::ClockInference)
82+
@unpack ts, eq_domain, var_domain, inferred = ci
83+
@unpack fullvars = ts
84+
@unpack graph, var_to_diff = ts.structure
85+
continuous_id = Ref(0)
86+
clock_to_id = Dict{TimeDomain, Int}()
87+
id_to_clock = TimeDomain[]
88+
eq_to_cid = Vector{Int}(undef, nsrcs(graph))
89+
cid_to_eq = Vector{Int}[]
90+
var_to_cid = Vector{Int}(undef, ndsts(graph))
91+
cid_to_var = Vector{Int}[]
92+
cid_counter = Ref(0)
93+
for (i, d) in enumerate(eq_domain)
94+
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
95+
continuous_id = continuous_id
96+
97+
get!(clock_to_id, d) do
98+
cid = (cid_counter[] += 1)
99+
push!(id_to_clock, d)
100+
if d isa Continuous
101+
continuous_id[] = cid
102+
end
103+
cid
104+
end
105+
end
106+
eq_to_cid[i] = cid
107+
resize_or_push!(cid_to_eq, i, cid)
108+
end
109+
continuous_id = continuous_id[]
110+
input_idxs = map(_ -> Int[], 1:cid_counter[])
111+
inputs = map(_ -> Any[], 1:cid_counter[])
112+
nvv = length(var_domain)
113+
for i in 1:nvv
114+
d = var_domain[i]
115+
cid = get(clock_to_id, d, 0)
116+
@assert cid!==0 "Internal error! Variable $(fullvars[i]) doesn't have a inferred time domain."
117+
var_to_cid[i] = cid
118+
v = fullvars[i]
119+
#TODO: remove Inferred*
120+
if istree(v) && (o = operation(v)) isa Operator &&
121+
input_timedomain(o) != output_timedomain(o)
122+
push!(input_idxs[cid], i)
123+
push!(inputs[cid], fullvars[i])
124+
end
125+
resize_or_push!(cid_to_var, i, cid)
126+
end
127+
128+
eqs = equations(ts)
129+
tss = similar(cid_to_eq, TearingState)
130+
for (id, ieqs) in enumerate(cid_to_eq)
131+
vars = cid_to_var[id]
132+
ts_i = ts
133+
fadj = Vector{Int}[]
134+
eqs_i = Equation[]
135+
eq_to_diff = DiffGraph(length(ieqs))
136+
ne = 0
137+
for (j, eq_i) in enumerate(ieqs)
138+
vars = copy(graph.fadjlist[eq_i])
139+
ne += length(vars)
140+
push!(fadj, vars)
141+
push!(eqs_i, eqs[eq_i])
142+
eq_to_diff[j] = ts_i.structure.eq_to_diff[eq_i]
143+
end
144+
@set! ts_i.structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
145+
@set! ts_i.structure.only_discrete = id != continuous_id
146+
@set! ts_i.sys.eqs = eqs_i
147+
@set! ts_i.structure.eq_to_diff = eq_to_diff
148+
tss[id] = ts_i
149+
# TODO: just mark current and sample variables as inputs
150+
end
151+
return tss, inputs
152+
153+
#id_to_clock, cid_to_eq, cid_to_var
154+
end

src/systems/systemstructure.jl

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ function quick_cancel_expr(expr)
2525
kws...))(expr)
2626
end
2727

28-
export SystemStructure, TransformationState, TearingState
28+
export SystemStructure, TransformationState, TearingState, structural_simplify!
2929
export initialize_system_structure, find_linear_equations
30-
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs
30+
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete
3131
export dervars_range, diffvars_range, algvars_range
3232
export DiffGraph, complete!
3333

@@ -143,7 +143,9 @@ Base.@kwdef mutable struct SystemStructure
143143
# or as `torn` to assert that tearing has run.
144144
graph::BipartiteGraph{Int, Nothing}
145145
solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
146+
only_discrete::Bool
146147
end
148+
is_only_discrete(s::SystemStructure) = s.only_discrete
147149
isdervar(s::SystemStructure, i) = invview(s.var_to_diff)[i] !== nothing
148150
function isalgvar(s::SystemStructure, i)
149151
s.var_to_diff[i] === nothing &&
@@ -258,6 +260,27 @@ function TearingState(sys; quick_cancel = false, check = true)
258260
idx = addvar!(dvar)
259261
end
260262

263+
dvar = var
264+
idx = varidx
265+
if ModelingToolkit.isoperator(dvar, ModelingToolkit.Shift)
266+
if !(idx in dervaridxs)
267+
push!(dervaridxs, idx)
268+
end
269+
op = operation(dvar)
270+
tt = op.t
271+
steps = op.steps
272+
v = arguments(dvar)[1]
273+
for s in (steps - 1):-1:1
274+
sf = Shift(tt, s)
275+
dvar = sf(v)
276+
idx = addvar!(dvar)
277+
if !(idx in dervaridxs)
278+
push!(dervaridxs, idx)
279+
end
280+
end
281+
idx = addvar!(v)
282+
end
283+
261284
if istree(var) && operation(var) isa Symbolics.Operator &&
262285
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
263286
set_incidence = false
@@ -281,7 +304,7 @@ function TearingState(sys; quick_cancel = false, check = true)
281304
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
282305
for dervaridx in dervaridxs
283306
dervar = fullvars[dervaridx]
284-
diffvar = arguments(dervar)[1]
307+
diffvar = lower_order_var(dervar)
285308
if !(diffvar in sorted_fullvars)
286309
push!(sorted_fullvars, diffvar)
287310
end
@@ -300,24 +323,12 @@ function TearingState(sys; quick_cancel = false, check = true)
300323
var_to_diff = DiffGraph(nvars, true)
301324
for dervaridx in dervaridxs
302325
dervar = fullvars[dervaridx]
303-
diffvar = arguments(dervar)[1]
326+
diffvar = lower_order_var(dervar)
304327
diffvaridx = var2idx[diffvar]
305328
push!(diffvars, diffvar)
306329
var_to_diff[diffvaridx] = dervaridx
307330
end
308331

309-
#=
310-
algvars = setdiff(states(sys), diffvars)
311-
for algvar in algvars
312-
# it could be that a variable appeared in the states, but never appeared
313-
# in the equations.
314-
algvaridx = get(var2idx, algvar, 0)
315-
#if algvaridx == 0
316-
# check ? throw(InvalidSystemException("The system is missing an equation for $algvar.")) : return nothing
317-
#end
318-
end
319-
=#
320-
321332
graph = BipartiteGraph(neqs, nvars, Val(false))
322333
for (ie, vars) in enumerate(symbolic_incidence), v in vars
323334
jv = var2idx[v]
@@ -330,7 +341,23 @@ function TearingState(sys; quick_cancel = false, check = true)
330341

331342
return TearingState(sys, fullvars,
332343
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
333-
complete(graph), nothing), Any[])
344+
complete(graph), nothing, false), Any[])
345+
end
346+
347+
function lower_order_var(dervar)
348+
if isdifferential(dervar)
349+
diffvar = arguments(dervar)[1]
350+
else # shift
351+
s = operation(dervar)
352+
step = s.steps - 1
353+
vv = arguments(dervar)[1]
354+
if step >= 1
355+
diffvar = Shift(s.t, step)(vv)
356+
else
357+
diffvar = vv
358+
end
359+
end
360+
diffvar
334361
end
335362

336363
using .BipartiteGraphs: Label, BipartiteAdjacencyList
@@ -424,4 +451,21 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
424451
complete(ms.var_eq_matching, nsrcs(graph))))
425452
end
426453

454+
# TODO: clean up
455+
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
456+
check_consistency = true, kwargs...)
457+
has_io = io !== nothing
458+
has_io && ModelingToolkit.markio!(state, io...)
459+
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
460+
sys, ag = ModelingToolkit.alias_elimination!(state; kwargs...)
461+
if check_consistency
462+
ModelingToolkit.check_consistency(state, ag)
463+
end
464+
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
465+
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
466+
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
467+
ModelingToolkit.invalidate_cache!(sys)
468+
return has_io ? (sys, input_idxs) : sys
469+
end
470+
427471
end # module

0 commit comments

Comments
 (0)