Skip to content

Commit 8ae5148

Browse files
feat: subset variables appropriately in clock inference
1 parent 4f7aac4 commit 8ae5148

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

src/systems/clock_inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ function split_system(ci::ClockInference{S}) where {S}
199199

200200
# breaks the system up into a continous and 0 or more discrete systems
201201
tss = similar(cid_to_eq, S)
202-
for (id, ieqs) in enumerate(cid_to_eq)
203-
ts_i = system_subset(ts, ieqs)
202+
for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var))
203+
ts_i = system_subset(ts, ieqs, ivars)
204204
if id != continuous_id
205205
ts_i = shift_discrete_system(ts_i)
206206
@set! ts_i.structure.only_discrete = true

src/systems/systemstructure.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,11 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
219219
end
220220

221221
TransformationState(sys::AbstractSystem) = TearingState(sys)
222-
function system_subset(ts::TearingState, ieqs::Vector{Int})
222+
function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int})
223223
eqs = equations(ts)
224224
@set! ts.sys.eqs = eqs[ieqs]
225225
@set! ts.original_eqs = ts.original_eqs[ieqs]
226-
@set! ts.structure = system_subset(ts.structure, ieqs)
226+
@set! ts.structure = system_subset(ts.structure, ieqs, ivars)
227227
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
228228
names = Symbol[]
229229
for eq in get_eqs(ts.sys)
@@ -240,22 +240,35 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
240240
else
241241
@set! ts.statemachines = eltype(ts.statemachines)[]
242242
end
243+
@set! ts.fullvars = ts.fullvars[ivars]
243244
ts
244245
end
245246

246-
function system_subset(structure::SystemStructure, ieqs::Vector{Int})
247-
@unpack graph, eq_to_diff = structure
247+
function system_subset(structure::SystemStructure, ieqs::Vector{Int}, ivars::Vector{Int})
248+
@unpack graph = structure
248249
fadj = Vector{Int}[]
249250
eq_to_diff = DiffGraph(length(ieqs))
251+
var_to_diff = DiffGraph(length(ivars))
252+
250253
ne = 0
254+
old_to_new_var = zeros(Int, ndsts(graph))
255+
for (i, iv) in enumerate(ivars)
256+
old_to_new_var[iv] = i
257+
end
258+
for (i, iv) in enumerate(ivars)
259+
structure.var_to_diff[iv] === nothing && continue
260+
var_to_diff[i] = old_to_new_var[structure.var_to_diff[iv]]
261+
end
251262
for (j, eq_i) in enumerate(ieqs)
252-
ivars = copy(graph.fadjlist[eq_i])
253-
ne += length(ivars)
254-
push!(fadj, ivars)
263+
var_adj = [old_to_new_var[i] for i in graph.fadjlist[eq_i]]
264+
@assert all(!iszero, var_adj)
265+
ne += length(var_adj)
266+
push!(fadj, var_adj)
255267
eq_to_diff[j] = structure.eq_to_diff[eq_i]
256268
end
257-
@set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
269+
@set! structure.graph = complete(BipartiteGraph(ne, fadj, length(ivars)))
258270
@set! structure.eq_to_diff = eq_to_diff
271+
@set! structure.var_to_diff = complete(var_to_diff)
259272
structure
260273
end
261274

0 commit comments

Comments
 (0)