Skip to content

Commit eb3b716

Browse files
authored
Merge pull request #2146 from SciML/myb/clock_generic
More generic clock split
2 parents 584dd0b + 8ea964e commit eb3b716

File tree

2 files changed

+34
-24
lines changed

2 files changed

+34
-24
lines changed

src/systems/clock_inference.jl

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,18 @@ end
6767
function resize_or_push!(v, val, idx)
6868
n = length(v)
6969
if idx > n
70-
for i in (n + 1):idx
70+
for _ in (n + 1):idx
7171
push!(v, Int[])
7272
end
7373
resize!(v, idx)
7474
end
7575
push!(v[idx], val)
7676
end
7777

78-
function split_system(ci::ClockInference)
78+
function split_system(ci::ClockInference{S}) where {S}
7979
@unpack ts, eq_domain, var_domain, inferred = ci
8080
fullvars = get_fullvars(ts)
81-
@unpack graph, var_to_diff = ts.structure
81+
@unpack graph = ts.structure
8282
continuous_id = Ref(0)
8383
clock_to_id = Dict{TimeDomain, Int}()
8484
id_to_clock = TimeDomain[]
@@ -121,26 +121,10 @@ function split_system(ci::ClockInference)
121121
resize_or_push!(cid_to_var, i, cid)
122122
end
123123

124-
eqs = equations(ts)
125-
tss = similar(cid_to_eq, TearingState)
124+
tss = similar(cid_to_eq, S)
126125
for (id, ieqs) in enumerate(cid_to_eq)
127-
vars = cid_to_var[id]
128-
ts_i = ts
129-
fadj = Vector{Int}[]
130-
eqs_i = Equation[]
131-
eq_to_diff = DiffGraph(length(ieqs))
132-
ne = 0
133-
for (j, eq_i) in enumerate(ieqs)
134-
vars = copy(graph.fadjlist[eq_i])
135-
ne += length(vars)
136-
push!(fadj, vars)
137-
push!(eqs_i, eqs[eq_i])
138-
eq_to_diff[j] = ts_i.structure.eq_to_diff[eq_i]
139-
end
140-
@set! ts_i.structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
126+
ts_i = system_subset(ts, ieqs)
141127
@set! ts_i.structure.only_discrete = id != continuous_id
142-
@set! ts_i.sys.eqs = eqs_i
143-
@set! ts_i.structure.eq_to_diff = eq_to_diff
144128
tss[id] = ts_i
145129
end
146130
return tss, inputs, continuous_id, id_to_clock
@@ -213,16 +197,20 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
213197
@unpack u, p, t = integrator
214198
c2d_obs = $cont_to_disc_obs
215199
d2c_obs = $disc_to_cont_obs
200+
# Like Sample
216201
c2d_view = view(p, $cont_to_disc_idxs)
202+
# Like Hold
217203
d2c_view = view(p, $disc_to_cont_idxs)
218204
disc_state = view(p, $disc_range)
219205
disc = $disc
220-
# Write continuous info to discrete
221-
# Write discrete info to continuous
206+
# Write continuous into to discrete: handles `Sample`
222207
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
208+
# Write discrete into to continuous
209+
# get old discrete states
223210
copyto!(d2c_view, d2c_obs(disc_state, p, t))
224211
push!(saved_values.t, t)
225212
push!(saved_values.saveval, $save_vec)
213+
# update discrete states
226214
$empty_disc || disc(disc_state, disc_state, p, t)
227215
end)
228216
sv = SavedValues(Float64, Vector{Float64})

src/systems/systemstructure.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export initialize_system_structure, find_linear_equations
3131
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete
3232
export dervars_range, diffvars_range, algvars_range
3333
export DiffGraph, complete!
34-
export get_fullvars
34+
export get_fullvars, system_subset
3535

3636
struct DiffGraph <: Graphs.AbstractGraph{Int}
3737
primal_to_diff::Vector{Union{Int, Nothing}}
@@ -205,6 +205,28 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
205205
end
206206

207207
TransformationState(sys::AbstractSystem) = TearingState(sys)
208+
function system_subset(ts::TearingState, ieqs::Vector{Int})
209+
eqs = equations(ts)
210+
@set! ts.sys.eqs = eqs[ieqs]
211+
@set! ts.structure = system_subset(ts.structure, ieqs)
212+
ts
213+
end
214+
215+
function system_subset(structure::SystemStructure, ieqs::Vector{Int})
216+
@unpack graph, eq_to_diff = structure
217+
fadj = Vector{Int}[]
218+
eq_to_diff = DiffGraph(length(ieqs))
219+
ne = 0
220+
for (j, eq_i) in enumerate(ieqs)
221+
ivars = copy(graph.fadjlist[eq_i])
222+
ne += length(ivars)
223+
push!(fadj, ivars)
224+
eq_to_diff[j] = structure.eq_to_diff[eq_i]
225+
end
226+
@set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
227+
@set! structure.eq_to_diff = eq_to_diff
228+
structure
229+
end
208230

209231
function Base.show(io::IO, state::TearingState)
210232
print(io, "TearingState of ", typeof(state.sys))

0 commit comments

Comments
 (0)