Skip to content

Commit ac01392

Browse files
committed
Clean up graph initialization code
1 parent 2e559ac commit ac01392

File tree

4 files changed

+60
-51
lines changed

4 files changed

+60
-51
lines changed

src/bipartite_graph.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,13 @@ badjlist = [[1,2,5,6],[3,4,6]]
6666
bg = BipartiteGraph(7, fadjlist, badjlist)
6767
```
6868
"""
69-
mutable struct BipartiteGraph{I<:Integer} <: LightGraphs.AbstractGraph{I}
69+
mutable struct BipartiteGraph{I<:Integer,M} <: LightGraphs.AbstractGraph{I}
7070
ne::Int
7171
fadjlist::Vector{Vector{I}} # `fadjlist[src] => dsts`
7272
badjlist::Vector{Vector{I}} # `badjlist[dst] => srcs`
73+
metadata::M
7374
end
75+
BipartiteGraph(ne::Integer, fadj::AbstractVector, badj::AbstractVector) = BipartiteGraph(ne, fadj, badj, nothing)
7476

7577
"""
7678
```julia
@@ -91,9 +93,13 @@ $(SIGNATURES)
9193
9294
Build an empty `BipartiteGraph` with `nsrcs` sources and `ndsts` destinations.
9395
"""
94-
BipartiteGraph(nsrcs::T, ndsts::T) where T = BipartiteGraph(0, map(_->T[], 1:nsrcs), map(_->T[], 1:ndsts))
96+
function BipartiteGraph(nsrcs::T, ndsts::T; metadata=nothing) where T
97+
fadjlist = map(_->T[], 1:nsrcs)
98+
badjlist = map(_->T[], 1:ndsts)
99+
BipartiteGraph(0, fadjlist, badjlist, metadata)
100+
end
95101

96-
Base.eltype(::Type{BipartiteGraph{I}}) where I = I
102+
Base.eltype(::Type{<:BipartiteGraph{I}}) where I = I
97103
Base.empty!(g::BipartiteGraph) = (foreach(empty!, g.fadjlist); foreach(empty!, g.badjlist); g.ne = 0; g)
98104
Base.length(::BipartiteGraph) = error("length is not well defined! Use `ne` or `nv`.")
99105

@@ -106,8 +112,8 @@ LightGraphs.vertices(g::BipartiteGraph) = (𝑠vertices(g), 𝑑vertices(g))
106112
𝑑vertices(g::BipartiteGraph) = axes(g.badjlist, 1)
107113
has_𝑠vertex(g::BipartiteGraph, v::Integer) = v in 𝑠vertices(g)
108114
has_𝑑vertex(g::BipartiteGraph, v::Integer) = v in 𝑑vertices(g)
109-
𝑠neighbors(g::BipartiteGraph, i::Integer) = g.fadjlist[i]
110-
𝑑neighbors(g::BipartiteGraph, i::Integer) = g.badjlist[i]
115+
𝑠neighbors(g::BipartiteGraph, i::Integer, with_metadata::Val{M}=Val(false)) where M = M ? zip(g.fadjlist[i], g.metadata[i]) : g.fadjlist[i]
116+
𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M}=Val(false)) where M = M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j]
111117
LightGraphs.ne(g::BipartiteGraph) = g.ne
112118
LightGraphs.nv(g::BipartiteGraph) = sum(length, vertices(g))
113119
LightGraphs.edgetype(g::BipartiteGraph{I}) where I = BipartiteEdge{I}
@@ -170,7 +176,7 @@ Base.length(it::BipartiteEdgeIter{ALL}) = 2ne(it.g)
170176

171177
Base.eltype(it::BipartiteEdgeIter) = edgetype(it.g)
172178

173-
function Base.iterate(it::BipartiteEdgeIter{SRC,BipartiteGraph{T}}, state=(1, 1, SRC)) where T
179+
function Base.iterate(it::BipartiteEdgeIter{SRC,<:BipartiteGraph{T}}, state=(1, 1, SRC)) where T
174180
@unpack g = it
175181
neqs = nsrcs(g)
176182
neqs == 0 && return nothing
@@ -191,7 +197,7 @@ function Base.iterate(it::BipartiteEdgeIter{SRC,BipartiteGraph{T}}, state=(1, 1,
191197
return nothing
192198
end
193199

194-
function Base.iterate(it::BipartiteEdgeIter{DST,BipartiteGraph{T}}, state=(1, 1, DST)) where T
200+
function Base.iterate(it::BipartiteEdgeIter{DST,<:BipartiteGraph{T}}, state=(1, 1, DST)) where T
195201
@unpack g = it
196202
nvars = ndsts(g)
197203
nvars == 0 && return nothing

src/systems/alias_elimination.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
export alias_elimination, flatten
2-
31
using SymbolicUtils: Rewriters
42

53
function fixpoint_sub(x, dict)

src/systems/diffeqs/odesystem.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,31 @@ iv_from_nested_derivative(x) = missing
111111
vars(x::Sym) = [x]
112112
vars(exprs::Symbolic) = vars([exprs])
113113
vars(exprs) = foldl(vars!, exprs; init = Set())
114+
vars!(vars, eq::Equation) = (vars!(vars, eq.lhs); vars!(vars, eq.rhs); vars)
114115
function vars!(vars, O)
115116
isa(O, Sym) && return push!(vars, O)
116-
!isa(O, Symbolic) && return vars
117+
!istree(O) && return vars
118+
119+
operation(O) isa Differential && return push!(vars, O)
117120

118121
operation(O) isa Sym && push!(vars, O)
119-
for arg arguments(O)
122+
for arg in arguments(O)
120123
vars!(vars, arg)
121124
end
122125

123126
return vars
124127
end
125128

129+
find_derivatives!(vars, expr::Equation, f=identity) = (find_derivatives!(vars, expr.lhs, f); find_derivatives!(vars, expr.rhs, f); vars)
130+
function find_derivatives!(vars, expr, f)
131+
!istree(O) && return vars
132+
operation(O) isa Differential && push!(vars, f(O))
133+
for arg in arguments(O)
134+
vars!(vars, arg)
135+
end
136+
return vars
137+
end
138+
126139
function ODESystem(eqs, iv=nothing; kwargs...)
127140
# NOTE: this assumes that the order of algebric equations doesn't matter
128141
diffvars = OrderedSet()

src/systems/systemstructure.jl

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module SystemStructures
22

3+
using DataStructures
4+
using SymbolicUtils: istree, operation
35
using ..ModelingToolkit
46
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten
57
using SymbolicUtils: arguments
@@ -47,8 +49,8 @@ struct SystemStructure
4749
fullvars::Vector # [xvar; dxvars; algvars]
4850
varassoc::Vector{Int}
4951
algeqs::BitVector
50-
graph::BipartiteGraph{Int}
51-
solvable_graph::BipartiteGraph{Int}
52+
graph::BipartiteGraph{Int,Nothing}
53+
solvable_graph::BipartiteGraph{Int,Vector{Vector{Int}}}
5254
assign::Vector{Int}
5355
inv_assign::Vector{Int}
5456
scc::Vector{Vector{Int}}
@@ -108,55 +110,45 @@ function Base.show(io::IO, s::SystemStructure)
108110
show(io, S)
109111
end
110112

111-
# V-nodes `[x_1, x_2, x_3, ..., dx_1, dx_2, ..., y_1, y_2, ...]` where `x`s are
112-
# differential variables and `y`s are algebraic variables.
113-
function collect_variables(sys)
114-
dxvars = []
113+
function init_graph(sys)
114+
iv = independent_variable(sys)
115115
eqs = equations(sys)
116-
algeqs = trues(length(eqs))
117-
for (i, eq) in enumerate(eqs)
118-
if isdiffeq(eq)
119-
algeqs[i] = false
120-
lhs = eq.lhs
121-
# Make sure that the LHS is a first order derivative of a var.
122-
@assert !(arguments(lhs)[1] isa Differential) "The equation $eq is not first order"
116+
neqs = length(eqs)
117+
algeqs = trues(neqs)
118+
varsadj = Vector{Any}(undef, neqs)
119+
dervars = OrderedSet()
120+
diffvars = OrderedSet()
123121

124-
push!(dxvars, lhs)
122+
for (i, eq) in enumerate(eqs)
123+
vars = OrderedSet()
124+
vars!(vars, eq)
125+
for var in vars
126+
if istree(var) && operation(var) isa Differential
127+
diffvar = arguments(var)[1]
128+
@assert !(diffvar isa Differential) "The equation [ $eq ] is not first order"
129+
push!(dervars, var)
130+
push!(diffvars, diffvar)
131+
end
125132
end
133+
varsadj[i] = vars
126134
end
127135

128-
xvars = (first var_from_nested_derivative).(dxvars)
129-
algvars = setdiff(states(sys), xvars)
130-
return xvars, dxvars, algvars, algeqs
131-
end
136+
algvars = setdiff(states(sys), diffvars)
137+
fullvars = [collect(diffvars); collect(dervars); algvars]
132138

133-
function init_graph(sys)
134-
xvars, dxvars, algvars, algeqs = collect_variables(sys)
135-
dxvar_offset = length(xvars)
139+
dxvar_offset = length(diffvars)
136140
algvar_offset = 2dxvar_offset
137141

138-
fullvars = [xvars; dxvars; algvars]
139-
eqs = equations(sys)
140-
idxmap = Dict(fullvars .=> 1:length(fullvars))
141-
graph = BipartiteGraph(length(eqs), length(fullvars))
142-
solvable_graph = BipartiteGraph(length(eqs), length(fullvars))
142+
nvars = length(fullvars)
143+
idxmap = Dict(fullvars .=> 1:nvars)
144+
graph = BipartiteGraph(neqs, nvars)
145+
solvable_graph = BipartiteGraph(neqs, nvars, metadata=Vector{Int}[])
143146

144-
vs = Set()
145-
for (i, eq) in enumerate(eqs)
146-
# TODO: custom vars that handles D(x)
147-
# TODO: add checks here
148-
lhs = eq.lhs
149-
if isdiffeq(eq)
150-
v = lhs
151-
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
152-
else
153-
vars!(vs, lhs)
154-
end
155-
vars!(vs, eq.rhs)
147+
for (i, vs) in enumerate(varsadj)
156148
for v in vs
157-
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
149+
j = get(idxmap, v, nothing)
150+
j === nothing || add_edge!(graph, i, idxmap[v])
158151
end
159-
empty!(vs)
160152
end
161153

162154
varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list

0 commit comments

Comments
 (0)