Skip to content

Commit ca2a2d3

Browse files
committed
More generic clock split
1 parent 4c20852 commit ca2a2d3

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

src/systems/clock_inference.jl

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,18 @@ end
6767
function resize_or_push!(v, val, idx)
6868
n = length(v)
6969
if idx > n
70-
for i in (n + 1):idx
70+
for _ in (n + 1):idx
7171
push!(v, Int[])
7272
end
7373
resize!(v, idx)
7474
end
7575
push!(v[idx], val)
7676
end
7777

78-
function split_system(ci::ClockInference)
78+
function split_system(ci::ClockInference{S}) where {S}
7979
@unpack ts, eq_domain, var_domain, inferred = ci
8080
fullvars = get_fullvars(ts)
81-
@unpack graph, var_to_diff = ts.structure
81+
@unpack graph = ts.structure
8282
continuous_id = Ref(0)
8383
clock_to_id = Dict{TimeDomain, Int}()
8484
id_to_clock = TimeDomain[]
@@ -121,26 +121,10 @@ function split_system(ci::ClockInference)
121121
resize_or_push!(cid_to_var, i, cid)
122122
end
123123

124-
eqs = equations(ts)
125-
tss = similar(cid_to_eq, TearingState)
124+
tss = similar(cid_to_eq, S)
126125
for (id, ieqs) in enumerate(cid_to_eq)
127-
vars = cid_to_var[id]
128-
ts_i = ts
129-
fadj = Vector{Int}[]
130-
eqs_i = Equation[]
131-
eq_to_diff = DiffGraph(length(ieqs))
132-
ne = 0
133-
for (j, eq_i) in enumerate(ieqs)
134-
vars = copy(graph.fadjlist[eq_i])
135-
ne += length(vars)
136-
push!(fadj, vars)
137-
push!(eqs_i, eqs[eq_i])
138-
eq_to_diff[j] = ts_i.structure.eq_to_diff[eq_i]
139-
end
140-
@set! ts_i.structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
126+
ts_i = system_subset(ts, ieqs)
141127
@set! ts_i.structure.only_discrete = id != continuous_id
142-
@set! ts_i.sys.eqs = eqs_i
143-
@set! ts_i.structure.eq_to_diff = eq_to_diff
144128
tss[id] = ts_i
145129
end
146130
return tss, inputs, continuous_id, id_to_clock

src/systems/systemstructure.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export initialize_system_structure, find_linear_equations
3131
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete
3232
export dervars_range, diffvars_range, algvars_range
3333
export DiffGraph, complete!
34-
export get_fullvars
34+
export get_fullvars, system_subset
3535

3636
struct DiffGraph <: Graphs.AbstractGraph{Int}
3737
primal_to_diff::Vector{Union{Int, Nothing}}
@@ -205,6 +205,28 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
205205
end
206206

207207
TransformationState(sys::AbstractSystem) = TearingState(sys)
208+
function system_subset(ts::TearingState, ieqs::Vector{Int})
209+
eqs = equations(ts)
210+
@set! ts.sys.eqs = eqs[ieqs]
211+
@set! ts.structure = system_subset(ts.structure, ieqs)
212+
ts
213+
end
214+
215+
function system_subset(structure::SystemStructure, ieqs::Vector{Int})
216+
@unpack graph, eq_to_diff = structure
217+
fadj = Vector{Int}[]
218+
eq_to_diff = DiffGraph(length(ieqs))
219+
ne = 0
220+
for (j, eq_i) in enumerate(ieqs)
221+
ivars = copy(graph.fadjlist[eq_i])
222+
ne += length(ivars)
223+
push!(fadj, ivars)
224+
eq_to_diff[j] = structure.eq_to_diff[eq_i]
225+
end
226+
@set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
227+
@set! structure.eq_to_diff = eq_to_diff
228+
structure
229+
end
208230

209231
function Base.show(io::IO, state::TearingState)
210232
print(io, "TearingState of ", typeof(state.sys))

0 commit comments

Comments
 (0)