Skip to content

Commit ee1ec70

Browse files
authored
Merge pull request #767 from SciML/myb/print
Modularize BipartiteGraphs and SystemStructures
2 parents 211b7af + a85f702 commit ee1ec70

File tree

5 files changed

+78
-15
lines changed

5 files changed

+78
-15
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2121
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2222
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2323
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
24+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2425
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2526
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
2627
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -49,6 +50,7 @@ LightGraphs = "1.3"
4950
MacroTools = "0.5"
5051
NaNMath = "0.3"
5152
RecursiveArrayTools = "2.3"
53+
Reexport = "1"
5254
Requires = "1.0"
5355
RuntimeGeneratedFunctions = "0.4, 0.5"
5456
SafeTestsets = "0.0.1"

src/ModelingToolkit.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
55
using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
8+
using Setfield
89
using DiffEqJump
910
using DataStructures
1011
using SpecialFunctions, NaNMath
@@ -202,6 +203,7 @@ Get the set of parameters variables for the given system.
202203
function parameters end
203204

204205
include("bipartite_graph.jl")
206+
using .BipartiteGraphs
205207

206208
include("variables.jl")
207209
include("context_dsl.jl")
@@ -216,7 +218,6 @@ include("domains.jl")
216218
include("register_function.jl")
217219

218220
include("systems/abstractsystem.jl")
219-
include("systems/systemstructure.jl")
220221

221222
include("systems/diffeqs/odesystem.jl")
222223
include("systems/diffeqs/sdesystem.jl")
@@ -239,6 +240,9 @@ include("systems/pde/pdesystem.jl")
239240
include("systems/reaction/reactionsystem.jl")
240241
include("systems/dependency_graphs.jl")
241242

243+
include("systems/systemstructure.jl")
244+
using .SystemStructures
245+
242246
include("systems/reduction.jl")
243247

244248
include("latexify_recipes.jl")

src/bipartite_graph.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
module BipartiteGraphs
2+
3+
export BipartiteEdge, BipartiteGraph
4+
5+
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
6+
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST
7+
8+
using DocStringExtensions
9+
using Reexport
110
using UnPack
211
using SparseArrays
3-
using LightGraphs
12+
@reexport using LightGraphs
413
using Setfield
514

615
###
@@ -229,3 +238,5 @@ function LightGraphs.incidence_matrix(g::BipartiteGraph, val=true)
229238
end
230239
S = sparse(I, J, val, nsrcs(g), ndsts(g))
231240
end
241+
242+
end # module

src/solve.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,34 +56,37 @@ end
5656
# return the coefficient matrix `A` and a
5757
# vector of constants (possibly symbolic) `b` such that
5858
# A \ b will solve the equations for the vars
59-
function A_b(eqs::AbstractArray, vars::AbstractArray)
59+
function A_b(eqs::AbstractArray, vars::AbstractArray, check)
6060
exprs = rhss(eqs) .- lhss(eqs)
61-
for ex in exprs
62-
@assert islinear(ex, vars)
61+
if check
62+
for ex in exprs
63+
@assert islinear(ex, vars)
64+
end
6365
end
6466
A = jacobian(exprs, vars)
6567
b = A * vars - exprs
6668
A, b
6769
end
68-
function A_b(eq, var)
70+
function A_b(eq, var, check)
6971
ex = eq.rhs - eq.lhs
70-
@assert islinear(ex, [var])
72+
check && @assert islinear(ex, [var])
7173
a = expand_derivatives(Differential(var)(ex))
7274
b = a * var - ex
7375
a, b
7476
end
7577

7678
"""
77-
solve_for(eqs::Vector, vars::Vector)
79+
solve_for(eqs::Vector, vars::Vector; simplify=true, check=true)
7880
7981
Solve the vector of equations `eqs` for a set of variables `vars`.
8082
8183
Assumes `length(eqs) == length(vars)`
8284
83-
Currently only works if all equations are linear.
85+
Currently only works if all equations are linear. `check` if the expr is linear
86+
w.r.t `vars`.
8487
"""
85-
function solve_for(eqs, vars; simplify=true)
86-
A, b = A_b(eqs, vars)
88+
function solve_for(eqs, vars; simplify=true, check=true)
89+
A, b = A_b(eqs, vars, check)
8790
#TODO: we need to make sure that `solve_for(eqs, vars)` contains no `vars`
8891
_solve(A, b, simplify)
8992
end

src/systems/systemstructure.jl

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
module SystemStructures
2+
3+
using ..ModelingToolkit
4+
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten
5+
using SymbolicUtils: arguments
6+
using ..BipartiteGraphs
7+
using UnPack
8+
using Setfield
19
using SparseArrays
210

311
#=
@@ -27,10 +35,18 @@ for v in 𝑣vertices(graph); active_𝑣vertices[v] || continue
2735
end
2836
=#
2937

38+
export SystemStructure, initialize_system_structure
39+
export diffvars_range, dervars_range, algvars_range
40+
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq
41+
export DIFFERENTIAL_VARIABLE, ALGEBRAIC_VARIABLE, DERIVATIVE_VARIABLE
42+
export DIFFERENTIAL_EQUATION, ALGEBRAIC_EQUATION
43+
export vartype, eqtype
44+
3045
struct SystemStructure
3146
dxvar_offset::Int
3247
fullvars::Vector # [xvar; dxvars; algvars]
3348
varassoc::Vector{Int}
49+
algeqs::BitVector
3450
graph::BipartiteGraph{Int}
3551
solvable_graph::BipartiteGraph{Int}
3652
assign::Vector{Int}
@@ -39,12 +55,35 @@ struct SystemStructure
3955
partitions::Vector{NTuple{4, Vector{Int}}}
4056
end
4157

58+
diffvars_range(s::SystemStructure) = 1:s.dxvar_offset
59+
dervars_range(s::SystemStructure) = s.dxvar_offset+1:2s.dxvar_offset
60+
algvars_range(s::SystemStructure) = 2s.dxvar_offset+1:length(s.fullvars)
61+
62+
isdiffvar(s::SystemStructure, var::Integer) = var in diffvars_range(s)
63+
isdervar(s::SystemStructure, var::Integer) = var in dervars_range(s)
64+
isalgvar(s::SystemStructure, var::Integer) = var in algvars_range(s)
65+
66+
@enum VariableType DIFFERENTIAL_VARIABLE ALGEBRAIC_VARIABLE DERIVATIVE_VARIABLE
67+
68+
function vartype(s::SystemStructure, var::Integer)::VariableType
69+
isdiffvar(s, var) ? DIFFERENTIAL_VARIABLE :
70+
isdervar(s, var) ? DERIVATIVE_VARIABLE :
71+
isalgvar(s, var) ? ALGEBRAIC_VARIABLE : error("Variable $var out of bounds")
72+
end
73+
74+
@enum EquationType DIFFERENTIAL_EQUATION ALGEBRAIC_EQUATION
75+
76+
isalgeq(s::SystemStructure, eq::Integer) = s.algeqs[eq]
77+
isdiffeq(s::SystemStructure, eq::Integer) = !isalgeq(s, eq)
78+
eqtype(s::SystemStructure, eq::Integer)::EquationType = isalgeq(s, eq) ? ALGEBRAIC_EQUATION : DIFFERENTIAL_EQUATION
79+
4280
function initialize_system_structure(sys)
43-
sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph = init_graph(flatten(sys))
81+
sys, dxvar_offset, fullvars, varassoc, algeqs, graph, solvable_graph = init_graph(flatten(sys))
4482
@set sys.structure = SystemStructure(
4583
dxvar_offset,
4684
fullvars,
4785
varassoc,
86+
algeqs,
4887
graph,
4988
solvable_graph,
5089
Int[],
@@ -74,8 +113,10 @@ end
74113
function collect_variables(sys)
75114
dxvars = []
76115
eqs = equations(sys)
116+
algeqs = trues(length(eqs))
77117
for (i, eq) in enumerate(eqs)
78118
if isdiffeq(eq)
119+
algeqs[i] = false
79120
lhs = eq.lhs
80121
# Make sure that the LHS is a first order derivative of a var.
81122
@assert !(arguments(lhs)[1] isa Differential) "The equation $eq is not first order"
@@ -86,11 +127,11 @@ function collect_variables(sys)
86127

87128
xvars = (first var_from_nested_derivative).(dxvars)
88129
algvars = setdiff(states(sys), xvars)
89-
return xvars, dxvars, algvars
130+
return xvars, dxvars, algvars, algeqs
90131
end
91132

92133
function init_graph(sys)
93-
xvars, dxvars, algvars = collect_variables(sys)
134+
xvars, dxvars, algvars, algeqs = collect_variables(sys)
94135
dxvar_offset = length(xvars)
95136
algvar_offset = 2dxvar_offset
96137

@@ -119,5 +160,7 @@ function init_graph(sys)
119160
end
120161

121162
varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list
122-
sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph
163+
sys, dxvar_offset, fullvars, varassoc, algeqs, graph, solvable_graph
123164
end
165+
166+
end # module

0 commit comments

Comments
 (0)