Skip to content

Commit ea7207d

Browse files
authored
Merge pull request #738 from SciML/myb/bg
Add more methods on BipartiteGraph and add SystemStructure
2 parents 79b07a1 + ecdcdb1 commit ea7207d

File tree

10 files changed

+418
-75
lines changed

10 files changed

+418
-75
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2424
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2525
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
2626
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
27+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2728
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2829
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2930
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -50,9 +51,10 @@ RecursiveArrayTools = "2.3"
5051
Requires = "1.0"
5152
RuntimeGeneratedFunctions = "0.4, 0.5"
5253
SafeTestsets = "0.0.1"
54+
Setfield = "0.7"
5355
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
5456
StaticArrays = "0.10, 0.11, 0.12, 1.0"
55-
SymbolicUtils = "0.7.3"
57+
SymbolicUtils = "0.7.4"
5658
TreeViews = "0.3"
5759
UnPack = "0.1, 1.0"
5860
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ Get the set of parameters variables for the given system.
168168
"""
169169
function parameters end
170170

171+
include("bipartite_graph.jl")
172+
171173
include("variables.jl")
172174
include("context_dsl.jl")
173175
include("differentials.jl")
@@ -181,6 +183,7 @@ include("domains.jl")
181183
include("register_function.jl")
182184

183185
include("systems/abstractsystem.jl")
186+
include("systems/systemstructure.jl")
184187

185188
include("systems/diffeqs/odesystem.jl")
186189
include("systems/diffeqs/sdesystem.jl")
@@ -211,6 +214,7 @@ include("extra_functions.jl")
211214

212215
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
213216
export SDESystem, SDEFunction, SDEFunctionExpr, SDESystemExpr
217+
export SystemStructure
214218
export JumpSystem
215219
export ODEProblem, SDEProblem
216220
export NonlinearProblem, NonlinearProblemExpr

src/bipartite_graph.jl

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
using UnPack
2+
using SparseArrays
3+
using LightGraphs
4+
using Setfield
5+
6+
###
7+
### Edges & Vertex
8+
###
9+
@enum VertType SRC DST ALL
10+
11+
struct BipartiteEdge{I<:Integer} <: LightGraphs.AbstractEdge{I}
12+
src::I
13+
dst::I
14+
function BipartiteEdge(src::I, dst::V) where {I,V}
15+
T = promote_type(I, V)
16+
new{T}(T(src), T(dst))
17+
end
18+
end
19+
20+
LightGraphs.src(edge::BipartiteEdge) = edge.src
21+
LightGraphs.dst(edge::BipartiteEdge) = edge.dst
22+
23+
function Base.show(io::IO, edge::BipartiteEdge)
24+
@unpack src, dst = edge
25+
print(io, "[src: ", src, "] => [dst: ", dst, "]")
26+
end
27+
28+
Base.:(==)(a::BipartiteEdge, b::BipartiteEdge) = src(a) == src(b) && dst(a) == dst(b)
29+
30+
###
31+
### Graph
32+
###
33+
"""
34+
$(TYPEDEF)
35+
36+
A bipartite graph representation between two, possibly distinct, sets of vertices
37+
(source and dependencies). Maps source vertices, labelled `1:N₁`, to vertices
38+
on which they depend (labelled `1:N₂`).
39+
40+
# Fields
41+
$(FIELDS)
42+
43+
# Example
44+
```julia
45+
using ModelingToolkit
46+
47+
ne = 4
48+
srcverts = 1:4
49+
depverts = 1:2
50+
51+
# six source vertices
52+
fadjlist = [[1],[1],[2],[2],[1],[1,2]]
53+
54+
# two vertices they depend on
55+
badjlist = [[1,2,5,6],[3,4,6]]
56+
57+
bg = BipartiteGraph(7, fadjlist, badjlist)
58+
```
59+
"""
60+
mutable struct BipartiteGraph{I<:Integer} <: LightGraphs.AbstractGraph{I}
61+
ne::Int
62+
fadjlist::Vector{Vector{I}} # `fadjlist[src] => dsts`
63+
badjlist::Vector{Vector{I}} # `badjlist[dst] => srcs`
64+
end
65+
66+
"""
67+
```julia
68+
Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T<:Integer}
69+
```
70+
71+
Test whether two [`BipartiteGraph`](@ref)s are equal.
72+
"""
73+
function Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T<:Integer}
74+
iseq = (bg1.ne == bg2.ne)
75+
iseq &= (bg1.fadjlist == bg2.fadjlist)
76+
iseq &= (bg1.badjlist == bg2.badjlist)
77+
iseq
78+
end
79+
80+
"""
81+
$(SIGNATURES)
82+
83+
Build an empty `BipartiteGraph` with `nsrcs` sources and `ndsts` destinations.
84+
"""
85+
BipartiteGraph(nsrcs::T, ndsts::T) where T = BipartiteGraph(0, map(_->T[], 1:nsrcs), map(_->T[], 1:ndsts))
86+
87+
Base.eltype(::Type{BipartiteGraph{I}}) where I = I
88+
Base.empty!(g::BipartiteGraph) = (foreach(empty!, g.fadjlist); foreach(empty!, g.badjlist); g.ne = 0; g)
89+
Base.length(::BipartiteGraph) = error("length is not well defined! Use `ne` or `nv`.")
90+
91+
if isdefined(LightGraphs, :has_contiguous_vertices)
92+
LightGraphs.has_contiguous_vertices(::Type{<:BipartiteGraph}) = false
93+
end
94+
LightGraphs.is_directed(::Type{<:BipartiteGraph}) = false
95+
LightGraphs.vertices(g::BipartiteGraph) = (𝑠vertices(g), 𝑑vertices(g))
96+
𝑠vertices(g::BipartiteGraph) = axes(g.fadjlist, 1)
97+
𝑑vertices(g::BipartiteGraph) = axes(g.badjlist, 1)
98+
has_𝑠vertex(g::BipartiteGraph, v::Integer) = v in 𝑠vertices(g)
99+
has_𝑑vertex(g::BipartiteGraph, v::Integer) = v in 𝑑vertices(g)
100+
𝑠neighbors(g::BipartiteGraph, i::Integer) = g.fadjlist[i]
101+
𝑑neighbors(g::BipartiteGraph, i::Integer) = g.badjlist[i]
102+
LightGraphs.ne(g::BipartiteGraph) = g.ne
103+
LightGraphs.nv(g::BipartiteGraph) = sum(length, vertices(g))
104+
LightGraphs.edgetype(g::BipartiteGraph{I}) where I = BipartiteEdge{I}
105+
106+
nsrcs(g::BipartiteGraph) = length(𝑠vertices(g))
107+
ndsts(g::BipartiteGraph) = length(𝑑vertices(g))
108+
109+
function LightGraphs.has_edge(g::BipartiteGraph, edge::BipartiteEdge)
110+
@unpack src, dst = edge
111+
(src in 𝑠vertices(g) && dst in 𝑑vertices(g)) || return false # edge out of bounds
112+
insorted(𝑠neighbors(src), dst)
113+
end
114+
115+
###
116+
### Populate
117+
###
118+
LightGraphs.add_edge!(g::BipartiteGraph, i::Integer, j::Integer) = add_edge!(g, BipartiteEdge(i, j))
119+
function LightGraphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge)
120+
@unpack fadjlist, badjlist = g
121+
verts = vertices(g)
122+
s, d = src(edge), dst(edge)
123+
(has_𝑠vertex(g, s) && has_𝑑vertex(g, d)) || error("edge ($edge) out of range.")
124+
@inbounds list = fadjlist[s]
125+
index = searchsortedfirst(list, d)
126+
@inbounds (index <= length(list) && list[index] == d) && return false # edge already in graph
127+
insert!(list, index, d)
128+
129+
g.ne += 1
130+
@inbounds list = badjlist[d]
131+
index = searchsortedfirst(list, s)
132+
insert!(list, index, s)
133+
return true # edge successfully added
134+
end
135+
136+
function LightGraphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
137+
if type === DST
138+
push!(g.badjlist, T[])
139+
elseif type === SRC
140+
push!(g.fadjlist, T[])
141+
else
142+
error("type ($type) must be either `DST` or `SRC`")
143+
end
144+
return true # vertex successfully added
145+
end
146+
147+
###
148+
### Edges iteration
149+
###
150+
LightGraphs.edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(ALL))
151+
𝑠edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
152+
𝑑edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(DST))
153+
154+
struct BipartiteEdgeIter{T,G} <: LightGraphs.AbstractEdgeIter
155+
g::G
156+
type::Val{T}
157+
end
158+
159+
Base.length(it::BipartiteEdgeIter) = ne(it.g)
160+
Base.length(it::BipartiteEdgeIter{ALL}) = 2ne(it.g)
161+
162+
Base.eltype(it::BipartiteEdgeIter) = edgetype(it.g)
163+
164+
function Base.iterate(it::BipartiteEdgeIter{SRC,BipartiteGraph{T}}, state=(1, 1, SRC)) where T
165+
@unpack g = it
166+
neqs = nsrcs(g)
167+
neqs == 0 && return nothing
168+
eq, jvar = state
169+
170+
while eq <= neqs
171+
eq′ = eq
172+
vars = 𝑠neighbors(g, eq′)
173+
if jvar > length(vars)
174+
eq += 1
175+
jvar = 1
176+
continue
177+
end
178+
edge = BipartiteEdge(eq′, vars[jvar])
179+
state = (eq, jvar + 1, SRC)
180+
return edge, state
181+
end
182+
return nothing
183+
end
184+
185+
function Base.iterate(it::BipartiteEdgeIter{DST,BipartiteGraph{T}}, state=(1, 1, DST)) where T
186+
@unpack g = it
187+
nvars = ndsts(g)
188+
nvars == 0 && return nothing
189+
ieq, jvar = state
190+
191+
while jvar <= nvars
192+
eqs = 𝑑neighbors(g, jvar)
193+
if ieq > length(eqs)
194+
ieq = 1
195+
jvar += 1
196+
continue
197+
end
198+
edge = BipartiteEdge(eqs[ieq], jvar)
199+
state = (ieq + 1, jvar, DST)
200+
return edge, state
201+
end
202+
return nothing
203+
end
204+
205+
function Base.iterate(it::BipartiteEdgeIter{ALL,<:BipartiteGraph}, state=nothing)
206+
if state === nothing
207+
ss = iterate((@set it.type = Val(SRC)))
208+
elseif state[3] === SRC
209+
ss = iterate((@set it.type = Val(SRC)), state)
210+
elseif state[3] == DST
211+
ss = iterate((@set it.type = Val(DST)), state)
212+
end
213+
if ss === nothing && state[3] == SRC
214+
return iterate((@set it.type = Val(DST)))
215+
else
216+
return ss
217+
end
218+
end
219+
220+
###
221+
### Utils
222+
###
223+
function LightGraphs.incidence_matrix(g::BipartiteGraph, val=true)
224+
I = Int[]
225+
J = Int[]
226+
for i in 𝑠vertices(g), n in 𝑠neighbors(g, i)
227+
push!(I, i)
228+
push!(J, n)
229+
end
230+
S = sparse(I, J, val, nsrcs(g), ndsts(g))
231+
end

src/context_dsl.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ SymbolicUtils.@number_methods(Sym{Parameter{Real}},
1010
term(f, a, b), skipbasics)
1111

1212
SymbolicUtils.symtype(s::Symbolic{Parameter{T}}) where T = T
13-
SymbolicUtils.similarterm(t::Term{T}, f, args) where {T<:Parameter} = Term{T}(f, args)
13+
SymbolicUtils.similarterm(t::Term{<:Parameter}, f, args) = Term(f, args)
1414

1515
Base.convert(::Type{Num}, x::Symbolic{Parameter{T}}) where {T<:Number} = Num(x)
1616

src/differentials.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct Differential <: Function
2929
end
3030
(D::Differential)(x) = Term{symtype(x)}(D, [x])
3131
(D::Differential)(x::Num) = Num(D(value(x)))
32+
SymbolicUtils.promote_symtype(::Differential, x) = x
3233

3334
Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
3435

src/systems/abstractsystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,4 +297,3 @@ function (f::AbstractSysToExpr)(O)
297297
end
298298
return build_expr(:call, Any[operation(O); f.(arguments(O))])
299299
end
300-

0 commit comments

Comments
 (0)