Skip to content

Commit ca2347a

Browse files
committed
More sophisticated solvable_graph
1 parent f0e16cf commit ca2347a

File tree

1 file changed

+70
-18
lines changed

1 file changed

+70
-18
lines changed

src/systems/systemstructure.jl

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module SystemStructures
22

33
using DataStructures
4-
using SymbolicUtils: istree, operation, arguments
4+
using SymbolicUtils: istree, operation, arguments, Symbolic
55
using ..ModelingToolkit
66
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, value
77
using ..BipartiteGraphs
@@ -50,13 +50,15 @@ struct SystemStructure
5050
algeqs::BitVector
5151
graph::BipartiteGraph{Int,Nothing}
5252
solvable_graph::BipartiteGraph{Int,Vector{Vector{Int}}}
53+
linear_equations::Vector{Int}
5354
assign::Vector{Int}
5455
inv_assign::Vector{Int}
5556
scc::Vector{Vector{Int}}
5657
partitions::Vector{NTuple{4, Vector{Int}}}
5758
end
5859

5960
diffvars_range(s::SystemStructure) = 1:s.dxvar_offset
61+
# TODO: maybe dervars should be in the end.
6062
dervars_range(s::SystemStructure) = s.dxvar_offset+1:2s.dxvar_offset
6163
algvars_range(s::SystemStructure) = 2s.dxvar_offset+1:length(s.fullvars)
6264

@@ -79,7 +81,10 @@ isdiffeq(s::SystemStructure, eq::Integer) = !isalgeq(s, eq)
7981
eqtype(s::SystemStructure, eq::Integer)::EquationType = isalgeq(s, eq) ? ALGEBRAIC_EQUATION : DIFFERENTIAL_EQUATION
8082

8183
function initialize_system_structure(sys)
82-
sys, dxvar_offset, fullvars, varassoc, algeqs, graph, solvable_graph = init_graph(flatten(sys))
84+
sys, dxvar_offset, fullvars, varassoc, algeqs, graph = init_graph(flatten(sys))
85+
86+
solvable_graph = BipartiteGraph(neqs, nvars, metadata=map(_->Int[], 1:nsrcs(graph)))
87+
8388
@set sys.structure = SystemStructure(
8489
dxvar_offset,
8590
fullvars,
@@ -89,6 +94,7 @@ function initialize_system_structure(sys)
8994
solvable_graph,
9095
Int[],
9196
Int[],
97+
Int[],
9298
Vector{Int}[],
9399
NTuple{4, Vector{Int}}[]
94100
)
@@ -144,28 +150,74 @@ function init_graph(sys)
144150
nvars = length(fullvars)
145151
idxmap = Dict(fullvars .=> 1:nvars)
146152
graph = BipartiteGraph(neqs, nvars)
147-
solvable_graph = BipartiteGraph(neqs, nvars, metadata=map(_->Int[], 1:neqs))
148153

149-
for (i, vs) in enumerate(varsadj)
150-
eq = eqs[i]
154+
vs = Set()
155+
for (i, eq) in enumerate(eqs)
156+
# TODO: custom vars that handles D(x)
157+
# TODO: add checks here
158+
lhs = eq.lhs
159+
if isdiffeq(eq)
160+
v = lhs
161+
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
162+
else
163+
vars!(vs, lhs)
164+
end
165+
vars!(vs, eq.rhs)
151166
for v in vs
152-
j = get(idxmap, v, nothing)
153-
if j !== nothing
154-
add_edge!(graph, i, idxmap[v])
155-
j > algvar_offset || continue
156-
D = Differential(fullvars[j])
157-
c = value(expand_derivatives(D(eq.rhs - eq.lhs), false))
158-
if c isa Number && c != 0
159-
# 0 here is a sentinel value for non-integer coefficients
160-
coeff = c isa Integer ? c : 0
161-
add_edge!(solvable_graph, i, j, coeff)
162-
end
163-
end
167+
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
164168
end
169+
empty!(vs)
165170
end
166171

167172
varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list
168-
sys, dxvar_offset, fullvars, varassoc, algeqs, graph, solvable_graph
173+
sys, dxvar_offset, fullvars, varassoc, algeqs, graph
174+
end
175+
176+
function find_solvables!(sys)
177+
s = structure(sys)
178+
@unpack fullvars, graph, solvable_graph, linear_equations = s
179+
eqs = equations(sys)
180+
empty!(solvable_graph); empty!(linear_equations)
181+
for (i, eq) in enumerate(eqs); isdiffeq(eq) && continue
182+
term = value(eq.rhs - eq.lhs)
183+
linear_term = 0
184+
all_int_algvars = true
185+
for j in 𝑠neighbors(graph, i)
186+
if !isalgvar(s, j)
187+
all_int_algvars = false
188+
continue
189+
end
190+
var = fullvars[j]
191+
c = expand_derivatives(Differential(var)(term), false)
192+
# test if `var` is linear in `eq`.
193+
if !(c isa Symbolic) && c isa Number && c != 0
194+
if isinteger(c)
195+
c = convert(Integer, c)
196+
else
197+
all_int_algvars = false
198+
end
199+
linear_term += c * var
200+
add_edge!(solvable_graph, i, j, c)
201+
end
202+
end
203+
204+
# Check if there are only algebraic variables and the equation is both
205+
# linear and homogeneous, i.e. it is in the form of
206+
#
207+
# ``∑ c_i * a_i = 0``,
208+
#
209+
# where ``c_i`` ∈ ℤ and ``a_i`` denotes algebraic variables.
210+
if all_int_algvars && isequal(linear_term, term)
211+
push!(linear_equations, i)
212+
else
213+
# We use 0 as a sentinel value for a nonlinear or non-integer term.
214+
215+
# Don't move the reference, because it might lead to pointer
216+
# invalidations.
217+
fill!(solvable_graph.metadata[i], 0)
218+
end
219+
end
220+
s
169221
end
170222

171223
end # module

0 commit comments

Comments
 (0)