Skip to content

Commit 8264144

Browse files
committed
Give partition names
1 parent 3854947 commit 8264144

File tree

7 files changed

+59
-43
lines changed

7 files changed

+59
-43
lines changed

examples/electrical_components.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using Test
22
using ModelingToolkit, OrdinaryDiffEq
33

44
# Basic electric components
5-
const t = Sym{ModelingToolkit.Parameter{Real}}(:t)
5+
#const t = Sym{ModelingToolkit.Parameter{Real}}(:t)
6+
@parameters t
67
function Pin(;name)
78
@variables v(t) i(t)
89
ODESystem(Equation[], t, [v, i], [], name=name, defaults=[v=>1.0, i=>1.0])

src/structural_transformation/codegen.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,27 @@ function torn_system_jacobian_sparsity(sys)
4040
# dependencies.
4141
avars2dvars = Dict{Int,Set{Int}}()
4242
c = 0
43-
for (_, _, teqs, tvars) in partitions
43+
for partition in partitions
44+
@unpack e_residual, v_residual = partition
4445
# initialization
45-
for tvar in tvars
46+
for tvar in v_residual
4647
avars2dvars[tvar] = Set{Int}()
4748
end
48-
for teq in teqs
49+
for teq in e_residual
4950
c += 1
5051
for var in 𝑠neighbors(graph, teq)
5152
# Skip the tearing variables in the current partition, because
5253
# we are computing them from all the other states.
53-
LightGraphs.insorted(var, tvars) && continue
54+
LightGraphs.insorted(var, v_residual) && continue
5455
deps = get(avars2dvars, var, nothing)
5556
if deps === nothing # differential variable
5657
@assert !isalgvar(s, var)
57-
for tvar in tvars
58+
for tvar in v_residual
5859
push!(avars2dvars[tvar], var)
5960
end
6061
else # tearing variable from previous partitions
6162
@assert isalgvar(s, var)
62-
for tvar in tvars
63+
for tvar in v_residual
6364
union!(avars2dvars[tvar], avars2dvars[var])
6465
end
6566
end
@@ -97,22 +98,22 @@ function partitions_dag(s::SystemStructure)
9798
@unpack partitions, graph = s
9899

99100
# `partvars[i]` contains all the states that appear in `partitions[i]`
100-
partvars = map(partitions) do (_, _, reqs, tvars)
101+
partvars = map(partitions) do partition
101102
ipartvars = Set{Int}()
102-
for req in reqs
103+
for req in partition.e_residual
103104
union!(ipartvars, 𝑠neighbors(graph, req))
104105
end
105106
ipartvars
106107
end
107108

108109
I, J = Int[], Int[]
109110
n = length(partitions)
110-
for i in 1:n
111+
for (i, partition) in enumerate(partitions)
111112
for j in i+1:n
112113
# The only way for a later partition `j` to depend on an earlier
113114
# partition `i` is when `partvars[j]` contains one of tearing
114115
# variables of partition `i`.
115-
if !isdisjoint(partvars[j], partitions[i][4])
116+
if !isdisjoint(partvars[j], partition.v_residual)
116117
# j depends on i
117118
push!(I, i)
118119
push!(J, j)
@@ -170,8 +171,8 @@ function get_torn_eqs_vars(sys)
170171
vars = s.fullvars
171172
eqs = equations(sys)
172173

173-
torn_eqs = map(idxs-> eqs[idxs], map(x->x[3], partitions))
174-
torn_vars = map(idxs->vars[idxs], map(x->x[4], partitions))
174+
torn_eqs = map(idxs-> eqs[idxs], map(x->x.e_residual, partitions))
175+
torn_vars = map(idxs->vars[idxs], map(x->x.v_residual, partitions))
175176

176177
gen_nlsolve.((sys,), torn_eqs, torn_vars)
177178
end
@@ -243,9 +244,9 @@ given a set of `vars`, find the groups of equations we need to solve for
243244
to obtain the solution to `vars`
244245
"""
245246
function find_solve_sequence(partitions, vars)
246-
subset = filter(x -> !isdisjoint(x[4], vars), partitions)
247+
subset = filter(x -> !isdisjoint(x.v_residual, vars), partitions)
247248
isempty(subset) && return []
248-
vars′ = mapreduce(x->x[4], union, subset)
249+
vars′ = mapreduce(x->x.v_residual, union, subset)
249250
if vars′ == vars
250251
return subset
251252
else
@@ -289,8 +290,8 @@ function build_observed_function(
289290
if !isempty(subset)
290291
eqs = equations(sys)
291292

292-
torn_eqs = map(idxs-> eqs[idxs[3]], subset)
293-
torn_vars = map(idxs->fullvars[idxs[4]], subset)
293+
torn_eqs = map(idxs-> eqs[idxs.e_residual], subset)
294+
torn_vars = map(idxs->fullvars[idxs.v_residual], subset)
294295

295296
solves = gen_nlsolve.((sys,), torn_eqs, torn_vars)
296297
else

src/structural_transformation/tearing.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@ function tear_graph(sys)
88
s = structure(sys)
99
@unpack graph, solvable_graph, assign, inv_assign, scc = s
1010

11-
partitions = map(scc) do c
11+
@set! sys.structure.partitions = map(scc) do c
1212
ieqs = filter(eq->isalgeq(s, eq), c)
1313
vars = inv_assign[ieqs]
1414

1515
td = TraverseDAG(graph.fadjlist, length(assign))
16-
e_solved, v_solved, e_residue, v_tear = tearEquations!(td, solvable_graph.fadjlist, ieqs, vars)
16+
SystemPartition(tearEquations!(td, solvable_graph.fadjlist, ieqs, vars)...)
1717
end
18-
19-
@set! sys.structure.partitions = partitions
2018
return sys
2119
end
2220

@@ -37,7 +35,8 @@ function tearing_reassemble(sys; simplify=false)
3735
active_eqs = trues(ns)
3836
active_vars = trues(nd)
3937
rvar2req = Vector{Int}(undef, nd)
40-
for (ith_scc, (e_solved, v_solved, e_residue, v_tear)) in enumerate(partitions)
38+
for (ith_scc, partition) in enumerate(partitions)
39+
@unpack e_solved, v_solved, e_residual, v_residual = partition
4140
for ii in eachindex(e_solved)
4241
ieq = e_solved[ii]; ns -= 1
4342
iv = v_solved[ii]; nd -= 1
@@ -161,18 +160,19 @@ function tearing_reassemble(sys; simplify=false)
161160
### update partitions
162161
newpartitions = similar(partitions, 0)
163162
emptyintvec = Int[]
164-
for ii in eachindex(partitions)
165-
_, _, og_e_residue, og_v_tear = partitions[ii]
166-
isempty(og_v_tear) && continue
167-
e_residue = similar(og_e_residue)
168-
v_tear = similar(og_v_tear)
169-
for ii in eachindex(og_e_residue)
170-
e_residue[ii] = eq_reidx[og_e_residue[ii]]
171-
v_tear[ii] = var_reidx[og_v_tear[ii]]
163+
for (ii, partition) in enumerate(partitions)
164+
@unpack e_residual, v_residual = partition
165+
isempty(v_residual) && continue
166+
new_e_residual = similar(e_residual)
167+
new_v_residual = similar(v_residual)
168+
for ii in eachindex(e_residual)
169+
new_e_residual[ii] = eq_reidx[ e_residual[ii]]
170+
new_v_residual[ii] = var_reidx[v_residual[ii]]
172171
end
173172
# `emptyintvec` is aliased to save memory
174173
# We need them for type stability
175-
push!(newpartitions, (emptyintvec, emptyintvec, e_residue, v_tear))
174+
newpart = SystemPartition(emptyintvec, emptyintvec, new_e_residual, new_v_residual)
175+
push!(newpartitions, newpart)
176176
end
177177

178178
obseqs = solvars .~ rhss

src/structural_transformation/utils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,21 +262,21 @@ function reordered_matrix(sys, partitions=structure(sys).partitions)
262262
I, J = Int[], Int[]
263263
ii = 0
264264
M = Int[]
265-
for (e_solved, v_solved, e_residue, v_tear) in partitions
266-
append!(M, v_solved)
267-
append!(M, v_tear)
265+
for partition in partitions
266+
append!(M, partition.v_solved)
267+
append!(M, partition.v_residual)
268268
end
269269
M = inverse_mapping(vcat(M, setdiff(1:nvars, M)))
270-
for (e_solved, v_solved, e_residue, v_tear) in partitions
271-
for es in e_solved
270+
for partition in partitions
271+
for es in partition.e_solved
272272
isdiffeq(eqs[es]) && continue
273273
ii += 1
274274
js = [M[x] for x in 𝑠neighbors(graph, es) if isalgvar(s, x)]
275275
append!(I, fill(ii, length(js)))
276276
append!(J, js)
277277
end
278278

279-
for er in e_residue
279+
for er in partition.e_residual
280280
isdiffeq(eqs[er]) && continue
281281
ii += 1
282282
js = [M[x] for x in 𝑠neighbors(graph, er) if isalgvar(s, x)]

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol; namespace=true)
244244
end
245245
end
246246

247-
throw(error("Variable $name does not exist"))
247+
throw(ArgumentError("Variable $name does not exist"))
248248
end
249249

250250
function Base.setproperty!(sys::AbstractSystem, prop::Symbol, val)
@@ -525,7 +525,7 @@ function structural_simplify(sys::AbstractSystem)
525525
sys = initialize_system_structure(alias_elimination(sys))
526526
check_consistency(structure(sys))
527527
if sys isa ODESystem
528-
sys = sort_states(dae_index_lowering(sys))
528+
sys = dae_index_lowering(sys)
529529
end
530530
sys = tearing(sys)
531531
fullstates = [map(eq->eq.lhs, observed(sys)); states(sys)]

src/systems/systemstructure.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,26 @@ for v in 𝑣vertices(graph); active_𝑣vertices[v] || continue
3737
end
3838
=#
3939

40-
export SystemStructure, initialize_system_structure, find_linear_equations
40+
export SystemStructure, SystemPartition
41+
export initialize_system_structure, find_linear_equations
4142
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq
4243
export dervars_range, diffvars_range, algvars_range
4344

4445
@enum VariableType::Int8 DIFFERENTIAL_VARIABLE ALGEBRAIC_VARIABLE DERIVATIVE_VARIABLE
4546

47+
Base.@kwdef struct SystemPartition
48+
e_solved::Vector{Int}
49+
v_solved::Vector{Int}
50+
e_residual::Vector{Int}
51+
v_residual::Vector{Int}
52+
end
53+
54+
function Base.:(==)(s1::SystemPartition, s2::SystemPartition)
55+
tup1 = (s1.e_solved, s1.v_solved, s1.e_residual, s1.v_residual)
56+
tup2 = (s2.e_solved, s2.v_solved, s2.e_residual, s2.v_residual)
57+
tup1 == tup2
58+
end
59+
4660
Base.@kwdef struct SystemStructure
4761
fullvars::Vector
4862
vartype::Vector{VariableType}
@@ -55,7 +69,7 @@ Base.@kwdef struct SystemStructure
5569
assign::Vector{Int}
5670
inv_assign::Vector{Int}
5771
scc::Vector{Vector{Int}}
58-
partitions::Vector{NTuple{4, Vector{Int}}}
72+
partitions::Vector{SystemPartition}
5973
end
6074

6175
isdervar(s::SystemStructure, var::Integer) = s.vartype[var] === DERIVATIVE_VARIABLE
@@ -168,7 +182,7 @@ function initialize_system_structure(sys)
168182
assign = Int[],
169183
inv_assign = Int[],
170184
scc = Vector{Int}[],
171-
partitions = NTuple{4, Vector{Int}}[],
185+
partitions = SystemPartition[],
172186
)
173187
return sys
174188
end

test/structural_transformation/tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ tornsys = tearing(sys)
5353
sss = structure(tornsys)
5454
@unpack graph, solvable_graph, assign, partitions = sss
5555
@test graph.fadjlist == [[1]]
56-
@test partitions == [([], [], [1], [1])]
56+
@test partitions == [StructuralTransformations.SystemPartition([], [], [1], [1])]
5757

5858
# Before:
5959
# u1 u2 u3 u4 u5

0 commit comments

Comments
 (0)