Skip to content

Commit 24d5041

Browse files
authored
Merge pull request #935 from SciML/myb/ss
Refactor & fix some minor bugs in alias elimination & add consistency check
2 parents 2dd2b61 + 79f37a3 commit 24d5041

23 files changed

+468
-295
lines changed

test/rc_model.jl renamed to examples/electrical_components.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ function Pin(;name)
88
ODESystem(Equation[], t, [v, i], [], name=name, defaults=[v=>1.0, i=>1.0])
99
end
1010

11-
function Ground(name)
11+
function Ground(;name)
1212
@named g = Pin()
1313
eqs = [g.v ~ 0]
1414
ODESystem(eqs, t, [], [], systems=[g], name=name)
1515
end
1616

17-
function ConstantVoltage(name; V = 1.0)
17+
function ConstantVoltage(;name, V = 1.0)
1818
val = V
1919
@named p = Pin()
2020
@named n = Pin()
@@ -26,7 +26,7 @@ function ConstantVoltage(name; V = 1.0)
2626
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val), name=name)
2727
end
2828

29-
function Resistor(name; R = 1.0)
29+
function Resistor(;name, R = 1.0)
3030
val = R
3131
@named p = Pin()
3232
@named n = Pin()
@@ -40,7 +40,7 @@ function Resistor(name; R = 1.0)
4040
ODESystem(eqs, t, [v], [R], systems=[p, n], defaults=Dict(R => val), name=name)
4141
end
4242

43-
function Capacitor(name; C = 1.0)
43+
function Capacitor(;name, C = 1.0)
4444
val = C
4545
@named p = Pin()
4646
@named n = Pin()
@@ -55,15 +55,23 @@ function Capacitor(name; C = 1.0)
5555
ODESystem(eqs, t, [v], [C], systems=[p, n], defaults=Dict(C => val), name=name)
5656
end
5757

58-
R = 1.0
59-
C = 1.0
60-
V = 1.0
61-
resistor = Resistor(:resistor, R=R)
62-
capacitor = Capacitor(:capacitor, C=C)
63-
source = ConstantVoltage(:source, V=V)
64-
ground = Ground(:ground)
58+
function Inductor(; name, L = 1.0)
59+
val = L
60+
@named p = Pin()
61+
@named n = Pin()
62+
@variables v(t) i(t)
63+
@parameters L
64+
D = Differential(t)
65+
eqs = [
66+
v ~ p.v - n.v
67+
0 ~ p.i + n.i
68+
i ~ p.i
69+
D(i) ~ v / L
70+
]
71+
ODESystem(eqs, t, [v, i], [L], systems=[p, n], defaults=Dict(L => val), name=name)
72+
end
6573

66-
function connect(ps...)
74+
function connect_pins(ps...)
6775
eqs = [
6876
0 ~ sum(p->p.i, ps) # KCL
6977
]
@@ -74,10 +82,3 @@ function connect(ps...)
7482

7583
return eqs
7684
end
77-
rc_eqs = [
78-
connect(source.p, resistor.p)
79-
connect(resistor.n, capacitor.p)
80-
connect(capacitor.n, source.n, ground.g)
81-
]
82-
83-
rc_model = ODESystem(rc_eqs, t, systems=[resistor, capacitor, source, ground], name=:rc)

examples/rc_model.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
include("electrical_components.jl")
2+
3+
R = 1.0
4+
C = 1.0
5+
V = 1.0
6+
@named resistor = Resistor(R=R)
7+
@named capacitor = Capacitor(C=C)
8+
@named source = ConstantVoltage(V=V)
9+
@named ground = Ground()
10+
11+
rc_eqs = [
12+
connect_pins(source.p, resistor.p)
13+
connect_pins(resistor.n, capacitor.p)
14+
connect_pins(capacitor.n, source.n, ground.g)
15+
]
16+
17+
@named rc_model = ODESystem(rc_eqs, t, systems=[resistor, capacitor, source, ground])

examples/serial_inductor.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
include("electrical_components.jl")
2+
3+
@named source = ConstantVoltage(V=10.0)
4+
@named resistor = Resistor(R=1.0)
5+
@named inductor1 = Inductor(L=1.0e-2)
6+
@named inductor2 = Inductor(L=2.0e-2)
7+
@named ground = Ground()
8+
9+
eqs = [
10+
connect_pins(source.p, resistor.p)
11+
connect_pins(resistor.n, inductor1.p)
12+
connect_pins(inductor1.n, inductor2.p)
13+
connect_pins(source.n, inductor2.n, ground.g)
14+
]
15+
16+
@named ll_model = ODESystem(eqs, t, systems=[source, resistor, inductor1, inductor2, ground])

src/bipartite_graph.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,13 @@ badjlist = [[1,2,5,6],[3,4,6]]
6666
bg = BipartiteGraph(7, fadjlist, badjlist)
6767
```
6868
"""
69-
mutable struct BipartiteGraph{I<:Integer,M} <: LightGraphs.AbstractGraph{I}
69+
mutable struct BipartiteGraph{I<:Integer,F<:Vector{Vector{I}},B<:Union{Vector{Vector{I}},I},M} <: LightGraphs.AbstractGraph{I}
7070
ne::Int
71-
fadjlist::Vector{Vector{I}} # `fadjlist[src] => dsts`
72-
badjlist::Vector{Vector{I}} # `badjlist[dst] => srcs`
71+
fadjlist::F # `fadjlist[src] => dsts`
72+
badjlist::B # `badjlist[dst] => srcs` or `ndsts`
7373
metadata::M
7474
end
75-
BipartiteGraph(ne::Integer, fadj::AbstractVector, badj::AbstractVector) = BipartiteGraph(ne, fadj, badj, nothing)
75+
BipartiteGraph(ne::Integer, fadj::AbstractVector, badj::Union{AbstractVector,Integer}=maximum(maximum, fadj); metadata=nothing) = BipartiteGraph(ne, fadj, badj, metadata)
7676

7777
"""
7878
```julia
@@ -93,16 +93,16 @@ $(SIGNATURES)
9393
9494
Build an empty `BipartiteGraph` with `nsrcs` sources and `ndsts` destinations.
9595
"""
96-
function BipartiteGraph(nsrcs::T, ndsts::T; metadata=nothing) where T
96+
function BipartiteGraph(nsrcs::T, ndsts::T, backedge::Val{B}=Val(true); metadata=nothing) where {T,B}
9797
fadjlist = map(_->T[], 1:nsrcs)
98-
badjlist = map(_->T[], 1:ndsts)
98+
badjlist = B ? map(_->T[], 1:ndsts) : ndsts
9999
BipartiteGraph(0, fadjlist, badjlist, metadata)
100100
end
101101

102102
Base.eltype(::Type{<:BipartiteGraph{I}}) where I = I
103103
function Base.empty!(g::BipartiteGraph)
104104
foreach(empty!, g.fadjlist)
105-
foreach(empty!, g.badjlist)
105+
g.badjlist isa AbstractVector && foreach(empty!, g.badjlist)
106106
g.ne = 0
107107
if g.metadata !== nothing
108108
foreach(empty!, g.metadata)
@@ -111,17 +111,22 @@ function Base.empty!(g::BipartiteGraph)
111111
end
112112
Base.length(::BipartiteGraph) = error("length is not well defined! Use `ne` or `nv`.")
113113

114+
@noinline throw_no_back_edges() = throw(ArgumentError("The graph has no back edges."))
115+
114116
if isdefined(LightGraphs, :has_contiguous_vertices)
115117
LightGraphs.has_contiguous_vertices(::Type{<:BipartiteGraph}) = false
116118
end
117119
LightGraphs.is_directed(::Type{<:BipartiteGraph}) = false
118120
LightGraphs.vertices(g::BipartiteGraph) = (𝑠vertices(g), 𝑑vertices(g))
119121
𝑠vertices(g::BipartiteGraph) = axes(g.fadjlist, 1)
120-
𝑑vertices(g::BipartiteGraph) = axes(g.badjlist, 1)
122+
𝑑vertices(g::BipartiteGraph) = g.badjlist isa AbstractVector ? axes(g.badjlist, 1) : Base.OneTo(g.badjlist)
121123
has_𝑠vertex(g::BipartiteGraph, v::Integer) = v in 𝑠vertices(g)
122124
has_𝑑vertex(g::BipartiteGraph, v::Integer) = v in 𝑑vertices(g)
123125
𝑠neighbors(g::BipartiteGraph, i::Integer, with_metadata::Val{M}=Val(false)) where M = M ? zip(g.fadjlist[i], g.metadata[i]) : g.fadjlist[i]
124-
𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M}=Val(false)) where M = M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j]
126+
function 𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M}=Val(false)) where M
127+
g.badjlist isa AbstractVector || throw_no_back_edges()
128+
M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j]
129+
end
125130
LightGraphs.ne(g::BipartiteGraph) = g.ne
126131
LightGraphs.nv(g::BipartiteGraph) = sum(length, vertices(g))
127132
LightGraphs.edgetype(g::BipartiteGraph{I}) where I = BipartiteEdge{I}
@@ -145,7 +150,6 @@ const NO_METADATA = NoMetadata()
145150
LightGraphs.add_edge!(g::BipartiteGraph, i::Integer, j::Integer, md=NO_METADATA) = add_edge!(g, BipartiteEdge(i, j), md)
146151
function LightGraphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge, md=NO_METADATA)
147152
@unpack fadjlist, badjlist = g
148-
verts = vertices(g)
149153
s, d = src(edge), dst(edge)
150154
(has_𝑠vertex(g, s) && has_𝑑vertex(g, d)) || error("edge ($edge) out of range.")
151155
@inbounds list = fadjlist[s]
@@ -157,15 +161,21 @@ function LightGraphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge, md=NO_MET
157161
end
158162

159163
g.ne += 1
160-
@inbounds list = badjlist[d]
161-
index = searchsortedfirst(list, s)
162-
insert!(list, index, s)
164+
if badjlist isa AbstractVector
165+
@inbounds list = badjlist[d]
166+
index = searchsortedfirst(list, s)
167+
insert!(list, index, s)
168+
end
163169
return true # edge successfully added
164170
end
165171

166172
function LightGraphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
167173
if type === DST
168-
push!(g.badjlist, T[])
174+
if g.badjlist isa AbstractVector
175+
push!(g.badjlist, T[])
176+
else
177+
g.badjlist += 1
178+
end
169179
elseif type === SRC
170180
push!(g.fadjlist, T[])
171181
else

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using ModelingToolkit: ODESystem, var_from_nested_derivative, Differential,
1616
states, equations, vars, Symbolic, diff2term, value,
1717
operation, arguments, Sym, Term, simplify, solve_for,
1818
isdiffeq, isdifferential,
19-
get_structure, get_reduced_states, defaults
19+
get_structure, defaults, InvalidSystemException
2020

2121
using ModelingToolkit.BipartiteGraphs
2222
using ModelingToolkit.SystemStructures
@@ -31,7 +31,7 @@ using SparseArrays
3131

3232
using NonlinearSolve
3333

34-
export tearing, dae_index_lowering
34+
export tearing, dae_index_lowering, check_consistency
3535
export build_torn_function, build_observed_function, ODAEProblem
3636
export sorted_incidence_matrix
3737

src/structural_transformation/codegen.jl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,27 @@ function torn_system_jacobian_sparsity(sys)
4040
# dependencies.
4141
avars2dvars = Dict{Int,Set{Int}}()
4242
c = 0
43-
for (_, _, teqs, tvars) in partitions
43+
for partition in partitions
44+
@unpack e_residual, v_residual = partition
4445
# initialization
45-
for tvar in tvars
46+
for tvar in v_residual
4647
avars2dvars[tvar] = Set{Int}()
4748
end
48-
for teq in teqs
49+
for teq in e_residual
4950
c += 1
5051
for var in 𝑠neighbors(graph, teq)
5152
# Skip the tearing variables in the current partition, because
5253
# we are computing them from all the other states.
53-
LightGraphs.insorted(var, tvars) && continue
54+
LightGraphs.insorted(var, v_residual) && continue
5455
deps = get(avars2dvars, var, nothing)
5556
if deps === nothing # differential variable
5657
@assert !isalgvar(s, var)
57-
for tvar in tvars
58+
for tvar in v_residual
5859
push!(avars2dvars[tvar], var)
5960
end
6061
else # tearing variable from previous partitions
6162
@assert isalgvar(s, var)
62-
for tvar in tvars
63+
for tvar in v_residual
6364
union!(avars2dvars[tvar], avars2dvars[var])
6465
end
6566
end
@@ -97,22 +98,22 @@ function partitions_dag(s::SystemStructure)
9798
@unpack partitions, graph = s
9899

99100
# `partvars[i]` contains all the states that appear in `partitions[i]`
100-
partvars = map(partitions) do (_, _, reqs, tvars)
101+
partvars = map(partitions) do partition
101102
ipartvars = Set{Int}()
102-
for req in reqs
103+
for req in partition.e_residual
103104
union!(ipartvars, 𝑠neighbors(graph, req))
104105
end
105106
ipartvars
106107
end
107108

108109
I, J = Int[], Int[]
109110
n = length(partitions)
110-
for i in 1:n
111+
for (i, partition) in enumerate(partitions)
111112
for j in i+1:n
112113
# The only way for a later partition `j` to depend on an earlier
113114
# partition `i` is when `partvars[j]` contains one of tearing
114115
# variables of partition `i`.
115-
if !isdisjoint(partvars[j], partitions[i][4])
116+
if !isdisjoint(partvars[j], partition.v_residual)
116117
# j depends on i
117118
push!(I, i)
118119
push!(J, j)
@@ -170,8 +171,8 @@ function get_torn_eqs_vars(sys)
170171
vars = s.fullvars
171172
eqs = equations(sys)
172173

173-
torn_eqs = map(idxs-> eqs[idxs], map(x->x[3], partitions))
174-
torn_vars = map(idxs->vars[idxs], map(x->x[4], partitions))
174+
torn_eqs = map(idxs-> eqs[idxs], map(x->x.e_residual, partitions))
175+
torn_vars = map(idxs->vars[idxs], map(x->x.v_residual, partitions))
175176

176177
gen_nlsolve.((sys,), torn_eqs, torn_vars)
177178
end
@@ -197,7 +198,7 @@ function build_torn_function(
197198
)
198199

199200
s = structure(sys)
200-
states = s.fullvars[diffvars_range(s)]
201+
states = map(i->s.fullvars[i], diffvars_range(s))
201202
syms = map(Symbol, states)
202203

203204
expr = SymbolicUtils.Code.toexpr(
@@ -243,9 +244,9 @@ given a set of `vars`, find the groups of equations we need to solve for
243244
to obtain the solution to `vars`
244245
"""
245246
function find_solve_sequence(partitions, vars)
246-
subset = filter(x -> !isdisjoint(x[4], vars), partitions)
247+
subset = filter(x -> !isdisjoint(x.v_residual, vars), partitions)
247248
isempty(subset) && return []
248-
vars′ = mapreduce(x->x[4], union, subset)
249+
vars′ = mapreduce(x->x.v_residual, union, subset)
249250
if vars′ == vars
250251
return subset
251252
else
@@ -267,8 +268,8 @@ function build_observed_function(
267268
syms_set = Set(syms)
268269
s = structure(sys)
269270
@unpack partitions, fullvars, graph = s
270-
diffvars = fullvars[diffvars_range(s)]
271-
algvars = fullvars[algvars_range(s)]
271+
diffvars = map(i->fullvars[i], diffvars_range(s))
272+
algvars = map(i->fullvars[i], algvars_range(s))
272273

273274
required_algvars = Set(intersect(algvars, syms_set))
274275
obs = observed(sys)
@@ -290,8 +291,8 @@ function build_observed_function(
290291
if !isempty(subset)
291292
eqs = equations(sys)
292293

293-
torn_eqs = map(idxs-> eqs[idxs[3]], subset)
294-
torn_vars = map(idxs->fullvars[idxs[4]], subset)
294+
torn_eqs = map(idxs-> eqs[idxs.e_residual], subset)
295+
torn_vars = map(idxs->fullvars[idxs.v_residual], subset)
295296

296297
solves = gen_nlsolve.((sys,), torn_eqs, torn_vars; checkbounds=checkbounds)
297298
else
@@ -338,7 +339,7 @@ function ODAEProblem{iip}(
338339
) where {iip}
339340
s = structure(sys)
340341
@unpack fullvars = s
341-
dvs = fullvars[diffvars_range(s)]
342+
dvs = map(i->fullvars[i], diffvars_range(s))
342343
ps = parameters(sys)
343344
defs = defaults(sys)
344345

0 commit comments

Comments
 (0)