Skip to content

Commit 770c78c

Browse files
committed
Construct linear matrix and solvable graph at the same time
1 parent c9643fb commit 770c78c

File tree

5 files changed

+62
-62
lines changed

5 files changed

+62
-62
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25-
AliasGraph, filter_kwargs, lower_varname, setio
25+
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL
2626

2727
using ModelingToolkit.BipartiteGraphs
2828
import .BipartiteGraphs: invview, complete
@@ -43,7 +43,8 @@ using NonlinearSolve
4343
export tearing, partial_state_selection, dae_index_lowering, check_consistency
4444
export dummy_derivative
4545
export build_torn_function, build_observed_function, ODAEProblem
46-
export sorted_incidence_matrix, pantelides!, tearing_reassemble, find_solvables!
46+
export sorted_incidence_matrix, pantelides!, tearing_reassemble, find_solvables!,
47+
linear_subsys_adjmat!
4748
export tearing_assignments, tearing_substitution
4849
export torn_system_jacobian_sparsity
4950
export full_equations

src/structural_transformation/symbolics_tearing.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,8 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
6262
eq_diff = eq_derivative_graph!(s, ieq)
6363

6464
sys = ts.sys
65-
D = Differential(get_iv(sys))
6665
eq = equations(ts)[ieq]
67-
eq = ModelingToolkit.expand_derivatives(0 ~ D(eq.rhs - eq.lhs))
66+
eq = 0 ~ ModelingToolkit.derivative(eq.rhs - eq.lhs, get_iv(sys))
6867
push!(equations(ts), eq)
6968
# Analyze the new equation and update the graph/solvable_graph
7069
# First, copy the previous incidence and add the derivative terms.

src/structural_transformation/utils.jl

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,20 +160,24 @@ end
160160
### Structural and symbolic utilities
161161
###
162162

163-
function find_eq_solvables!(state::TearingState, ieq; may_be_zero = false,
163+
function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = nothing;
164+
may_be_zero = false,
164165
allow_symbolic = false, allow_parameter = true, kwargs...)
165166
fullvars = state.fullvars
166167
@unpack graph, solvable_graph = state.structure
167168
eq = equations(state)[ieq]
168169
term = value(eq.rhs - eq.lhs)
169-
to_rm = Int[]
170+
all_int_vars = true
171+
coeffs === nothing || empty!(coeffs)
172+
empty!(to_rm)
170173
for j in 𝑠neighbors(graph, ieq)
171174
var = fullvars[j]
172-
isirreducible(var) && continue
175+
isirreducible(var) && (all_int_vars = false; continue)
173176
a, b, islinear = linear_expansion(term, var)
174-
a = unwrap(a)
175-
islinear || continue
177+
a, b = unwrap(a), unwrap(b)
178+
islinear || (all_int_vars = false; continue)
176179
if a isa Symbolic
180+
all_int_vars = false
177181
if !allow_symbolic
178182
if allow_parameter
179183
all(ModelingToolkit.isparameter, vars(a)) || continue
@@ -184,7 +188,18 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero = false,
184188
add_edge!(solvable_graph, ieq, j)
185189
continue
186190
end
187-
(a isa Number) || continue
191+
if !(a isa Number)
192+
all_int_vars = false
193+
continue
194+
end
195+
# When the expression is linear with numeric `a`, then we can safely
196+
# only consider `b` for the following iterations.
197+
term = b
198+
if isone(abs(a))
199+
coeffs === nothing || push!(coeffs, convert(Int, a))
200+
else
201+
all_int_vars = false
202+
end
188203
if a != 0
189204
add_edge!(solvable_graph, ieq, j)
190205
else
@@ -198,19 +213,54 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero = false,
198213
for j in to_rm
199214
rem_edge!(graph, ieq, j)
200215
end
216+
all_int_vars, term
201217
end
202218

203219
function find_solvables!(state::TearingState; kwargs...)
204220
@assert state.structure.solvable_graph === nothing
205221
eqs = equations(state)
206222
graph = state.structure.graph
207223
state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
224+
to_rm = Int[]
208225
for ieq in 1:length(eqs)
209-
find_eq_solvables!(state, ieq; kwargs...)
226+
find_eq_solvables!(state, ieq, to_rm; kwargs...)
210227
end
211228
return nothing
212229
end
213230

231+
function linear_subsys_adjmat!(state::TransformationState)
232+
graph = state.structure.graph
233+
if state.structure.solvable_graph === nothing
234+
state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
235+
end
236+
linear_equations = Int[]
237+
eqs = equations(state.sys)
238+
eadj = Vector{Int}[]
239+
cadj = Vector{Int}[]
240+
coeffs = Int[]
241+
to_rm = Int[]
242+
for i in eachindex(eqs)
243+
all_int_vars, rhs = find_eq_solvables!(state, i, to_rm, coeffs)
244+
245+
# Check if all states in the equation is both linear and homogeneous,
246+
# i.e. it is in the form of
247+
#
248+
# ``∑ c_i * v_i = 0``,
249+
#
250+
# where ``c_i`` ∈ ℤ and ``v_i`` denotes states.
251+
if all_int_vars && Symbolics._iszero(rhs)
252+
push!(linear_equations, i)
253+
push!(eadj, copy(𝑠neighbors(graph, i)))
254+
push!(cadj, copy(coeffs))
255+
end
256+
end
257+
258+
mm = SparseMatrixCLIL(nsrcs(graph),
259+
ndsts(graph),
260+
linear_equations, eadj, cadj)
261+
return mm
262+
end
263+
214264
highest_order_variable_mask(ts) =
215265
let v2d = ts.structure.var_to_diff
216266
v -> isempty(outneighbors(v2d, v))

src/systems/alias_elimination.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Graphs.Experimental.Traversals
55
const KEEP = typemin(Int)
66

77
function alias_eliminate_graph!(state::TransformationState)
8-
mm = linear_subsys_adjmat(state)
8+
mm = linear_subsys_adjmat!(state)
99
if size(mm, 1) == 0
1010
ag = AliasGraph(ndsts(state.structure.graph))
1111
return ag, mm, ag, mm, BitSet() # No linear subsystems

src/systems/systemstructure.jl

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -351,56 +351,6 @@ function TearingState(sys; quick_cancel = false, check = true)
351351
complete(graph), nothing), Any[])
352352
end
353353

354-
function linear_subsys_adjmat(state::TransformationState)
355-
fullvars = state.fullvars
356-
graph = state.structure.graph
357-
is_linear_equations = falses(nsrcs(graph))
358-
eqs = equations(state.sys)
359-
eadj = Vector{Int}[]
360-
cadj = Vector{Int}[]
361-
coeffs = Int[]
362-
for (i, eq) in enumerate(eqs)
363-
empty!(coeffs)
364-
linear_term = 0
365-
all_int_vars = true
366-
367-
term = value(eq.rhs - eq.lhs)
368-
for j in 𝑠neighbors(graph, i)
369-
var = fullvars[j]
370-
a, b, islinear = linear_expansion(term, var)
371-
a = unwrap(a)
372-
if islinear && !(a isa Symbolic) && a isa Number && !isirreducible(var)
373-
if a == 1 || a == -1
374-
a = convert(Integer, a)
375-
linear_term += a * var
376-
push!(coeffs, a)
377-
else
378-
all_int_vars = false
379-
end
380-
end
381-
end
382-
383-
# Check if all states in the equation is both linear and homogeneous,
384-
# i.e. it is in the form of
385-
#
386-
# ``∑ c_i * v_i = 0``,
387-
#
388-
# where ``c_i`` ∈ ℤ and ``v_i`` denotes states.
389-
if all_int_vars && isequal(linear_term, term)
390-
is_linear_equations[i] = true
391-
push!(eadj, copy(𝑠neighbors(graph, i)))
392-
push!(cadj, copy(coeffs))
393-
else
394-
is_linear_equations[i] = false
395-
end
396-
end
397-
398-
linear_equations = findall(is_linear_equations)
399-
return SparseMatrixCLIL(nsrcs(graph),
400-
ndsts(graph),
401-
linear_equations, eadj, cadj)
402-
end
403-
404354
using .BipartiteGraphs: Label, BipartiteAdjacencyList
405355
struct SystemStructurePrintMatrix <:
406356
AbstractMatrix{Union{Label, Int, BipartiteAdjacencyList}}

0 commit comments

Comments
 (0)