Skip to content

Commit 3cda516

Browse files
YingboMabaggepinnen
andcommitted
Add split_system that splits the system by their time domain
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent 01f1412 commit 3cda516

File tree

3 files changed

+127
-5
lines changed

3 files changed

+127
-5
lines changed

src/systems/abstractsystem.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,15 +1031,20 @@ This will convert all `inputs` to parameters and allow them to be unconnected, i
10311031
simplification will allow models where `n_states = n_equations - n_inputs`.
10321032
"""
10331033
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
1034-
simplify_constants = true, kwargs...)
1034+
kwargs...)
10351035
sys = expand_connections(sys)
10361036
sys isa DiscreteSystem && return sys
10371037
state = TearingState(sys)
1038+
structural_simplify!(state, io; simplify, kwargs...)
1039+
end
1040+
1041+
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
1042+
kwargs...)
10381043
has_io = io !== nothing
10391044
has_io && markio!(state, io...)
10401045
state, input_idxs = inputs_to_parameters!(state, io)
10411046
sys, ag = alias_elimination!(state; kwargs...)
1042-
check_consistency(state, ag)
1047+
#check_consistency(state, ag)
10431048
sys = dummy_derivative(sys, state, ag; simplify)
10441049
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
10451050
@set! sys.observed = topsort_equations(observed(sys), fullstates)
@@ -1148,8 +1153,8 @@ end
11481153

11491154
function markio!(state, inputs, outputs; check = true)
11501155
fullvars = state.fullvars
1151-
inputset = Dict(inputs .=> false)
1152-
outputset = Dict(outputs .=> false)
1156+
inputset = Dict{Any, Bool}(i => false for i in inputs)
1157+
outputset = Dict{Any, Bool}(o => false for o in outputs)
11531158
for (i, v) in enumerate(fullvars)
11541159
if v in keys(inputset)
11551160
v = setio(v, true, false)

src/systems/clock_inference.jl

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function infer_clocks!(ci::ClockInference)
2828
@unpack ts, eq_domain, var_domain, inferred = ci
2929
@unpack fullvars = ts
3030
@unpack graph = ts.structure
31-
# TODO: add a graph time to do this lazily
31+
# TODO: add a graph type to do this lazily
3232
var_graph = SimpleGraph(ndsts(graph))
3333
for eq in 𝑠vertices(graph)
3434
vvs = 𝑠neighbors(graph, eq)
@@ -64,3 +64,78 @@ function infer_clocks!(ci::ClockInference)
6464

6565
return ci
6666
end
67+
68+
function resize_or_push!(v, val, idx)
69+
n = length(v)
70+
if idx > n
71+
for i in (n + 1):idx
72+
push!(v, Int[])
73+
end
74+
resize!(v, idx)
75+
end
76+
push!(v[idx], val)
77+
end
78+
79+
function split_system(ci::ClockInference)
80+
@unpack ts, eq_domain, var_domain, inferred = ci
81+
@unpack fullvars = ts
82+
@unpack graph = ts.structure
83+
continuous_id = 0
84+
clock_to_id = Dict{TimeDomain, Int}()
85+
id_to_clock = TimeDomain[]
86+
eq_to_cid = Vector{Int}(undef, nsrcs(graph))
87+
cid_to_eq = Vector{Int}[]
88+
var_to_cid = Vector{Int}(undef, ndsts(graph))
89+
cid_to_var = Vector{Int}[]
90+
cid = 0
91+
for (i, d) in enumerate(eq_domain)
92+
cid = get!(clock_to_id, d) do
93+
cid += 1
94+
push!(id_to_clock, d)
95+
if d isa Continuous
96+
continuous_id = cid
97+
end
98+
cid
99+
end
100+
eq_to_cid[i] = cid
101+
resize_or_push!(cid_to_eq, i, cid)
102+
end
103+
input_discrete = Int[]
104+
inputs = []
105+
for (i, d) in enumerate(var_domain)
106+
cid = get(clock_to_id, d, 0)
107+
@assert cid!==0 "Internal error!"
108+
var_to_cid[i] = cid
109+
v = fullvars[i]
110+
#TODO: remove Inferred*
111+
if cid == continuous_id && istree(v) && (o = operation(v)) isa Operator &&
112+
!(input_timedomain(o) isa Continuous)
113+
push!(input_discrete, i)
114+
push!(inputs, fullvars[i])
115+
end
116+
resize_or_push!(cid_to_var, i, cid)
117+
end
118+
119+
eqs = equations(ts)
120+
tss = similar(cid_to_eq, TearingState)
121+
for (id, ieqs) in enumerate(cid_to_eq)
122+
vars = cid_to_var[id]
123+
ts_i = ts
124+
fadj = Vector{Int}[]
125+
eqs_i = Equation[]
126+
var_set_i = BitSet(vars)
127+
ne = 0
128+
for eq_i in ieqs
129+
vars = copy(graph.fadjlist[eq_i])
130+
ne += length(vars)
131+
push!(fadj, vars)
132+
push!(eqs_i, eqs[eq_i])
133+
end
134+
@set! ts_i.structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
135+
@set! ts_i.sys.eqs = eqs_i
136+
tss[id] = ts_i
137+
end
138+
return tss, (; inputs, outputs = ())
139+
140+
#id_to_clock, cid_to_eq, cid_to_var
141+
end

test/clock.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,51 @@ eqs = [yd ~ Sample(t, dt)(y)
2121
y ~ x]
2222
@named sys = ODESystem(eqs)
2323
# compute equation and variables' time domains
24+
#TODO: test linearize
25+
26+
#=
27+
Differential(t)(x(t)) ~ u(t) - x(t)
28+
0 ~ Sample(Clock(t, 0.1))(y(t)) - yd(t)
29+
0 ~ kp*(r(t) - yd(t)) - ud(t)
30+
0 ~ Hold()(ud(t)) - u(t)
31+
0 ~ x(t) - y(t)
32+
33+
====
34+
By inference:
35+
36+
Differential(t)(x(t)) ~ u(t) - x(t)
37+
0 ~ Hold()(ud(t)) - u(t) # Hold()(ud(t)) is constant except in an event
38+
0 ~ x(t) - y(t)
39+
40+
0 ~ Sample(Clock(t, 0.1))(y(t)) - yd(t)
41+
0 ~ kp*(r(t) - yd(t)) - ud(t)
42+
43+
====
44+
45+
Differential(t)(x(t)) ~ u(t) - x(t)
46+
0 ~ Hold()(ud(t)) - u(t)
47+
0 ~ x(t) - y(t)
48+
49+
yd(t) := Sample(Clock(t, 0.1))(y(t))
50+
ud(t) := kp*(r(t) - yd(t))
51+
=#
52+
53+
#=
54+
D(x) ~ Shift(x, 0, dt) + 1 # this should never meet with continous variables
55+
=> (Shift(x, 0, dt) - Shift(x, -1, dt))/dt ~ Shift(x, 0, dt) + 1
56+
=> Shift(x, 0, dt) - Shift(x, -1, dt) ~ Shift(x, 0, dt) * dt + dt
57+
=> Shift(x, 0, dt) - Shift(x, 0, dt) * dt ~ Shift(x, -1, dt) + dt
58+
=> (1 - dt) * Shift(x, 0, dt) ~ Shift(x, -1, dt) + dt
59+
=> Shift(x, 0, dt) := (Shift(x, -1, dt) + dt) / (1 - dt) # Discrete system
60+
=#
2461

2562
ci, varmap = infer_clocks(sys)
2663
eqmap = ci.eq_domain
64+
tss, io = ModelingToolkit.split_system(deepcopy(ci))
65+
ts_c = deepcopy(tss[1])
66+
@set! ts_c.structure.solvable_graph = nothing
67+
sss, = ModelingToolkit.structural_simplify!(ts_c, io)
68+
@test equations(sss) == [D(x) ~ u - x]
2769

2870
d = Clock(t, dt)
2971
# Note that TearingState reorders the equations

0 commit comments

Comments
 (0)