Skip to content

Commit 8c2e426

Browse files
YingboMabaggepinnen
andcommitted
Compute inputs for each clock
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent 19e950f commit 8c2e426

File tree

5 files changed

+62
-44
lines changed

5 files changed

+62
-44
lines changed

src/inputoutput.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,13 @@ function inputs_to_parameters!(state::TransformationState, io)
302302
ps = parameters(sys)
303303

304304
if io !== nothing
305+
inputs, = io
305306
# Change order of new parameters to correspond to user-provided order in argument `inputs`
306307
d = Dict{Any, Int}()
307308
for (i, inp) in enumerate(new_parameters)
308309
d[inp] = i
309310
end
310-
permutation = [d[i] for i in io.inputs]
311+
permutation = [d[i] for i in inputs]
311312
new_parameters = new_parameters[permutation]
312313
end
313314

src/systems/abstractsystem.jl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,20 +1038,6 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
10381038
structural_simplify!(state, io; simplify, kwargs...)
10391039
end
10401040

1041-
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
1042-
kwargs...)
1043-
has_io = io !== nothing
1044-
has_io && markio!(state, io...)
1045-
state, input_idxs = inputs_to_parameters!(state, io)
1046-
sys, ag = alias_elimination!(state; kwargs...)
1047-
#check_consistency(state, ag)
1048-
sys = dummy_derivative(sys, state, ag; simplify)
1049-
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
1050-
@set! sys.observed = topsort_equations(observed(sys), fullstates)
1051-
invalidate_cache!(sys)
1052-
return has_io ? (sys, input_idxs) : sys
1053-
end
1054-
10551041
function eliminate_constants(sys::AbstractSystem)
10561042
if has_eqs(sys)
10571043
eqs = get_eqs(sys)
@@ -1066,7 +1052,7 @@ end
10661052

10671053
function io_preprocessing(sys::AbstractSystem, inputs,
10681054
outputs; simplify = false, kwargs...)
1069-
sys, input_idxs = structural_simplify(sys, (; inputs, outputs); simplify, kwargs...)
1055+
sys, input_idxs = structural_simplify(sys, (inputs, outputs); simplify, kwargs...)
10701056

10711057
eqs = equations(sys)
10721058
alg_start_idx = findfirst(!isdiffeq, eqs)
@@ -1169,9 +1155,13 @@ function markio!(state, inputs, outputs; check = true)
11691155
fullvars[i] = v
11701156
end
11711157
end
1172-
check && (all(values(inputset)) ||
1173-
error("Some specified inputs were not found in system. The following Dict indicates the found variables ",
1174-
inputset))
1158+
if check
1159+
ikeys = keys(filter(!last, inputset))
1160+
if !isempty(ikeys)
1161+
error("Some specified inputs were not found in system. The following variables were not found ",
1162+
ikeys)
1163+
end
1164+
end
11751165
check && (all(values(outputset)) ||
11761166
error("Some specified outputs were not found in system. The following Dict indicates the found variables ",
11771167
outputset))

src/systems/clock_inference.jl

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ function infer_clocks!(ci::ClockInference)
5858
vd = var_domain[v]
5959
eqs = 𝑑neighbors(graph, v)
6060
isempty(eqs) && continue
61-
eq = first(eqs)
62-
eq_domain[eq] = vd
61+
#eq = first(eqs)
62+
for eq in eqs
63+
eq_domain[eq] = vd
64+
end
6365
end
6466

6567
return ci
@@ -80,38 +82,42 @@ function split_system(ci::ClockInference)
8082
@unpack ts, eq_domain, var_domain, inferred = ci
8183
@unpack fullvars = ts
8284
@unpack graph = ts.structure
83-
continuous_id = 0
85+
continuous_id = Ref(0)
8486
clock_to_id = Dict{TimeDomain, Int}()
8587
id_to_clock = TimeDomain[]
8688
eq_to_cid = Vector{Int}(undef, nsrcs(graph))
8789
cid_to_eq = Vector{Int}[]
8890
var_to_cid = Vector{Int}(undef, ndsts(graph))
8991
cid_to_var = Vector{Int}[]
90-
cid = 0
92+
cid_counter = Ref(0)
9193
for (i, d) in enumerate(eq_domain)
92-
cid = get!(clock_to_id, d) do
93-
cid += 1
94-
push!(id_to_clock, d)
95-
if d isa Continuous
96-
continuous_id = cid
94+
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
95+
continuous_id = continuous_id
96+
97+
get!(clock_to_id, d) do
98+
cid = (cid_counter[] += 1)
99+
push!(id_to_clock, d)
100+
if d isa Continuous
101+
continuous_id[] = cid
102+
end
103+
cid
97104
end
98-
cid
99105
end
100106
eq_to_cid[i] = cid
101107
resize_or_push!(cid_to_eq, i, cid)
102108
end
103-
input_discrete = Int[]
104-
inputs = []
109+
input_idxs = map(_ -> Int[], 1:cid_counter[])
110+
inputs = map(_ -> Any[], 1:cid_counter[])
105111
for (i, d) in enumerate(var_domain)
106112
cid = get(clock_to_id, d, 0)
107113
@assert cid!==0 "Internal error!"
108114
var_to_cid[i] = cid
109115
v = fullvars[i]
110116
#TODO: remove Inferred*
111-
if cid == continuous_id && istree(v) && (o = operation(v)) isa Operator &&
112-
!(input_timedomain(o) isa Continuous)
113-
push!(input_discrete, i)
114-
push!(inputs, fullvars[i])
117+
if istree(v) && (o = operation(v)) isa Operator &&
118+
input_timedomain(o) != output_timedomain(o)
119+
push!(input_idxs[cid], i)
120+
push!(inputs[cid], fullvars[i])
115121
end
116122
resize_or_push!(cid_to_var, i, cid)
117123
end
@@ -123,20 +129,23 @@ function split_system(ci::ClockInference)
123129
ts_i = ts
124130
fadj = Vector{Int}[]
125131
eqs_i = Equation[]
132+
eq_to_diff = DiffGraph(length(ieqs))
126133
var_set_i = BitSet(vars)
127134
ne = 0
128-
for eq_i in ieqs
135+
for (j, eq_i) in enumerate(ieqs)
129136
vars = copy(graph.fadjlist[eq_i])
130137
ne += length(vars)
131138
push!(fadj, vars)
132139
push!(eqs_i, eqs[eq_i])
140+
eq_to_diff[j] = ts_i.structure.eq_to_diff[eq_i]
133141
end
134142
@set! ts_i.structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
135143
@set! ts_i.sys.eqs = eqs_i
144+
@set! ts_i.structure.eq_to_diff = eq_to_diff
136145
tss[id] = ts_i
137146
# TODO: just mark past and sample variables as inputs
138147
end
139-
return tss, (; inputs, outputs = ())
148+
return tss, inputs
140149

141150
#id_to_clock, cid_to_eq, cid_to_var
142151
end

src/systems/systemstructure.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function quick_cancel_expr(expr)
2525
kws...))(expr)
2626
end
2727

28-
export SystemStructure, TransformationState, TearingState
28+
export SystemStructure, TransformationState, TearingState, structural_simplify!
2929
export initialize_system_structure, find_linear_equations
3030
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs
3131
export dervars_range, diffvars_range, algvars_range
@@ -424,4 +424,19 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
424424
complete(ms.var_eq_matching, nsrcs(graph))))
425425
end
426426

427+
# TODO: clean up
428+
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
429+
kwargs...)
430+
has_io = io !== nothing
431+
has_io && ModelingToolkit.markio!(state, io...)
432+
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
433+
sys, ag = ModelingToolkit.alias_elimination!(state; kwargs...)
434+
#check_consistency(state, ag)
435+
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
436+
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
437+
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
438+
ModelingToolkit.invalidate_cache!(sys)
439+
return has_io ? (sys, input_idxs) : sys
440+
end
441+
427442
end # module

test/clock.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, Test
1+
using ModelingToolkit, Test, Setfield
22

33
function infer_clocks(sys)
44
ts = TearingState(sys)
@@ -14,6 +14,7 @@ D = Differential(t)
1414

1515
eqs = [yd ~ Sample(t, dt)(y)
1616
ud ~ kp * (r - yd)
17+
r ~ 1.0
1718

1819
# plant (time continuous part)
1920
u ~ Hold(ud)
@@ -61,19 +62,21 @@ By inference:
6162

6263
ci, varmap = infer_clocks(sys)
6364
eqmap = ci.eq_domain
64-
tss, io = ModelingToolkit.split_system(deepcopy(ci))
65-
ts_c = deepcopy(tss[1])
66-
@set! ts_c.structure.solvable_graph = nothing
67-
sss, = ModelingToolkit.structural_simplify!(ts_c, io)
65+
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
66+
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
6867
@test equations(sss) == [D(x) ~ u - x]
68+
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
69+
@test isempty(equations(sss))
70+
@test observed(sss) == [r ~ 1.0; yd ~ Sample(t, dt)(y); ud ~ kp * (r - yd)]
6971

7072
d = Clock(t, dt)
7173
# Note that TearingState reorders the equations
7274
@test eqmap[1] == Continuous()
7375
@test eqmap[2] == d
7476
@test eqmap[3] == d
75-
@test eqmap[4] == Continuous()
77+
@test eqmap[4] == d
7678
@test eqmap[5] == Continuous()
79+
@test eqmap[6] == Continuous()
7780

7881
@test varmap[yd] == d
7982
@test varmap[ud] == d

0 commit comments

Comments
 (0)