Skip to content

Commit dd98466

Browse files
committed
Add solvable_graph basics
1 parent ac01392 commit dd98466

File tree

6 files changed

+46
-12
lines changed

6 files changed

+46
-12
lines changed

src/bipartite_graph.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,12 @@ end
130130
###
131131
### Populate
132132
###
133-
LightGraphs.add_edge!(g::BipartiteGraph, i::Integer, j::Integer) = add_edge!(g, BipartiteEdge(i, j))
134-
function LightGraphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge)
133+
struct NoMetadata
134+
end
135+
const NO_METADATA = NoMetadata()
136+
137+
LightGraphs.add_edge!(g::BipartiteGraph, i::Integer, j::Integer, md=NO_METADATA) = add_edge!(g, BipartiteEdge(i, j), md)
138+
function LightGraphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge, md=NO_METADATA)
135139
@unpack fadjlist, badjlist = g
136140
verts = vertices(g)
137141
s, d = src(edge), dst(edge)
@@ -140,6 +144,9 @@ function LightGraphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge)
140144
index = searchsortedfirst(list, d)
141145
@inbounds (index <= length(list) && list[index] == d) && return false # edge already in graph
142146
insert!(list, index, d)
147+
if md !== NO_METADATA
148+
insert!(g.metadata[s], index, md)
149+
end
143150

144151
g.ne += 1
145152
@inbounds list = badjlist[d]

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
196196
end
197197

198198
sts = get_states(sys)
199+
@show sts
199200
i = findfirst(x->getname(x) == name, sts)
200201

201202
if i !== nothing

src/systems/alias_elimination.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ function fixpoint_sub(x, dict)
1010
return x
1111
end
1212

13-
function substitute_aliases(diffeqs, dict)
14-
lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,))
13+
function substitute_aliases(eqs, dict)
14+
sub = Base.Fix2(fixpoint_sub, dict)
15+
map(eq->eq.lhs ~ sub(eq.rhs), eqs)
1516
end
1617

1718
# Note that we reduce parameters, too
@@ -68,8 +69,13 @@ function maybe_alias(lhs, rhs, diff_vars, iv, conservative)
6869
end
6970
end
7071

71-
function alias_elimination(sys; conservative=true)
72+
function alias_elimination(sys)
7273
sys = flatten(sys)
74+
s = get_structure(sys)
75+
if !(s isa SystemStructure)
76+
sys = initialize_system_structure(sys)
77+
s = structure(sys)
78+
end
7379
iv = independent_variable(sys)
7480
eqs = equations(sys)
7581
diff_vars = filter(!isnothing, map(eqs) do eq
@@ -124,7 +130,7 @@ function alias_elimination(sys; conservative=true)
124130
@set! sys.eqs = substitute_aliases(neweqs, Dict(subs))
125131
@set! sys.states = newstates
126132
@set! sys.observed = [observed(sys); alias_eqs]
127-
return initialize_system_structure(sys)
133+
return
128134
end
129135

130136
"""

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ function collect_vars!(states, parameters, expr, iv)
178178
collect_var!(states, parameters, expr, iv)
179179
else
180180
for var in vars(expr)
181+
if istree(var) && operation(var) isa Differential
182+
var, _ = var_from_nested_derivative(var)
183+
end
181184
collect_var!(states, parameters, var, iv)
182185
end
183186
end

src/systems/systemstructure.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
module SystemStructures
22

33
using DataStructures
4-
using SymbolicUtils: istree, operation
4+
using SymbolicUtils: istree, operation, arguments
55
using ..ModelingToolkit
6-
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten
7-
using SymbolicUtils: arguments
6+
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, value
87
using ..BipartiteGraphs
98
using UnPack
109
using Setfield
@@ -46,7 +45,7 @@ export vartype, eqtype
4645

4746
struct SystemStructure
4847
dxvar_offset::Int
49-
fullvars::Vector # [xvar; dxvars; algvars]
48+
fullvars::Vector # [diffvars; dervars; algvars]
5049
varassoc::Vector{Int}
5150
algeqs::BitVector
5251
graph::BipartiteGraph{Int,Nothing}
@@ -122,14 +121,17 @@ function init_graph(sys)
122121
for (i, eq) in enumerate(eqs)
123122
vars = OrderedSet()
124123
vars!(vars, eq)
124+
isalgeq = true
125125
for var in vars
126126
if istree(var) && operation(var) isa Differential
127+
isalgeq = false
127128
diffvar = arguments(var)[1]
128129
@assert !(diffvar isa Differential) "The equation [ $eq ] is not first order"
129130
push!(dervars, var)
130131
push!(diffvars, diffvar)
131132
end
132133
end
134+
algeqs[i] = isalgeq
133135
varsadj[i] = vars
134136
end
135137

@@ -142,12 +144,23 @@ function init_graph(sys)
142144
nvars = length(fullvars)
143145
idxmap = Dict(fullvars .=> 1:nvars)
144146
graph = BipartiteGraph(neqs, nvars)
145-
solvable_graph = BipartiteGraph(neqs, nvars, metadata=Vector{Int}[])
147+
solvable_graph = BipartiteGraph(neqs, nvars, metadata=map(_->Int[], 1:neqs))
146148

147149
for (i, vs) in enumerate(varsadj)
150+
eq = eqs[i]
148151
for v in vs
149152
j = get(idxmap, v, nothing)
150-
j === nothing || add_edge!(graph, i, idxmap[v])
153+
if j !== nothing
154+
add_edge!(graph, i, idxmap[v])
155+
j > algvar_offset || continue
156+
D = Differential(fullvars[j])
157+
c = value(expand_derivatives(D(eq.rhs - eq.lhs), false))
158+
if c isa Number && c != 0
159+
# 0 here is a sentinel value for non-integer coefficients
160+
coeff = c isa Integer ? c : 0
161+
add_edge!(solvable_graph, i, j, coeff)
162+
end
163+
end
151164
end
152165
end
153166

test/reduction.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ eqs1 = [
6868

6969
lorenz = name -> ODESystem(eqs1,t,name=name)
7070
lorenz1 = lorenz(:lorenz1)
71+
ss = ModelingToolkit.get_structure(initialize_system_structure(lorenz1))
72+
@test isequal(ss.fullvars, [x, y, z, D(x), D(y), D(z), F, u])
73+
@test ss.solvable_graph.fadjlist == [[7], [8], [], [8]]
74+
@test ss.solvable_graph.metadata == [[1], [1], [], [1]]
7175
lorenz2 = lorenz(:lorenz2)
7276

7377
connected = ODESystem([s ~ a + lorenz1.x

0 commit comments

Comments
 (0)