Skip to content

Commit 4e5dd0b

Browse files
authored
Merge pull request #2187 from SciML/aliasfix
Add back rank2 (highest diffed variables) in alias elimination
2 parents 1f0be82 + 586a6b2 commit 4e5dd0b

File tree

9 files changed

+56
-31
lines changed

9 files changed

+56
-31
lines changed

examples/serial_inductor.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,20 @@ eqs = [connect(source.p, resistor.p)
1414

1515
@named ll_model = ODESystem(eqs, t)
1616
ll_model = compose(ll_model, [source, resistor, inductor1, inductor2, ground])
17+
18+
@named source = ConstantVoltage(V = 10.0)
19+
@named resistor1 = Resistor(R = 1.0)
20+
@named resistor2 = Resistor(R = 1.0)
21+
@named inductor1 = Inductor(L = 1.0e-2)
22+
@named inductor2 = Inductor(L = 2.0e-2)
23+
@named ground = Ground()
24+
25+
eqs = [connect(source.p, inductor1.p)
26+
connect(inductor1.n, resistor1.p)
27+
connect(inductor1.n, resistor2.p)
28+
connect(resistor1.n, resistor2.n)
29+
connect(resistor2.n, inductor2.p)
30+
connect(source.n, inductor2.n)
31+
connect(inductor2.n, ground.g)]
32+
@named ll2_model = ODESystem(eqs, t)
33+
ll2_model = compose(ll2_model, [source, resistor1, resistor2, inductor1, inductor2, ground])

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ export tearing_assignments, tearing_substitution
5454
export torn_system_jacobian_sparsity
5555
export full_equations
5656
export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask
57+
export computed_highest_diff_variables
5758

5859
include("utils.jl")
5960
include("pantelides.jl")

src/structural_transformation/pantelides.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
7070
end
7171

7272
"""
73-
computed_highest_diff_variables(var_to_diff)
73+
computed_highest_diff_variables(structure)
7474
7575
Computes which variables are the "highest-differentiated" for purposes of
7676
pantelides. Ordinarily this is relatively straightforward. However, in our

src/systems/alias_elimination.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
264264
return linear_variables
265265
end
266266

267-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
267+
function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
268+
@unpack graph, var_to_diff = structure
268269
mm = copy(mm_orig)
269270
linear_equations_set = BitSet(mm_orig.nzrows)
270271

@@ -279,6 +280,7 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
279280
v -> var_to_diff[v] === nothing === invview(var_to_diff)[v]
280281
end
281282
is_linear_variables = is_algebraic.(1:length(var_to_diff))
283+
is_highest_diff = computed_highest_diff_variables(structure)
282284
for i in 𝑠vertices(graph)
283285
# only consider linear algebraic equations
284286
(i in linear_equations_set && all(is_algebraic, 𝑠neighbors(graph, i))) &&
@@ -291,25 +293,31 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
291293

292294
local bar
293295
try
294-
bar = do_bareiss!(mm, mm_orig, is_linear_variables)
296+
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
295297
catch e
296298
e isa OverflowError || rethrow(e)
297299
mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig)
298-
bar = do_bareiss!(mm, mm_orig, is_linear_variables)
300+
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
299301
end
300302

301303
return mm, solvable_variables, bar
302304
end
303305

304-
function do_bareiss!(M, Mold, is_linear_variables)
306+
function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff)
305307
rank1r = Ref{Union{Nothing, Int}}(nothing)
308+
rank2r = Ref{Union{Nothing, Int}}(nothing)
306309
find_pivot = let rank1r = rank1r
307310
(M, k) -> begin
308311
if rank1r[] === nothing
309312
r = find_masked_pivot(is_linear_variables, M, k)
310313
r !== nothing && return r
311314
rank1r[] = k - 1
312315
end
316+
if rank2r[] === nothing
317+
r = find_masked_pivot(is_highest_diff, M, k)
318+
r !== nothing && return r
319+
rank2r[] = k - 1
320+
end
313321
# TODO: It would be better to sort the variables by
314322
# derivative order here to enable more elimination
315323
# opportunities.
@@ -334,15 +342,19 @@ function do_bareiss!(M, Mold, is_linear_variables)
334342
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
335343
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
336344

337-
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
345+
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
346+
rank2 = something(rank2r[], rank3)
338347
rank1 = something(rank1r[], rank2)
339-
(rank1, rank2, pivots)
348+
(rank1, rank2, rank3, pivots)
340349
end
341350

342-
function simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
343-
ils, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph,
344-
var_to_diff,
345-
ils)
351+
function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL)
352+
@unpack structure = state
353+
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
354+
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
355+
# subsystem of the system we're interested in.
356+
#
357+
ils, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(structure, ils)
346358

347359
## Step 2: Simplify the system using the Bareiss factorization
348360
rk1vars = BitSet(@view pivots[1:rank1])
@@ -362,14 +374,6 @@ function simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
362374
return ils
363375
end
364376

365-
function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL)
366-
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
367-
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
368-
# subsystem of the system we're interested in.
369-
#
370-
return simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
371-
end
372-
373377
function exactdiv(a::Integer, b)
374378
d, r = divrem(a, b)
375379
@assert r == 0

test/clock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ ci, varmap = infer_clocks(sys)
6767
eqmap = ci.eq_domain
6868
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
6969
sss, = SystemStructures._structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
70-
@test equations(sss) == [D(x) ~ u - y]
70+
@test equations(sss) == [D(x) ~ u - x]
7171
sss, = SystemStructures._structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
7272
@test isempty(equations(sss))
7373
@test observed(sss) == [yd ~ Sample(t, dt)(y); r ~ 1.0; ud ~ kp * (r - yd)]

test/components.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ u0 = states(sys) .=> 0
159159
prob = DAEProblem(sys, Differential(t).(states(sys)) .=> 0, u0, (0, 0.5))
160160
@test_nowarn sol = solve(prob, DFBDF())
161161

162+
sys2 = structural_simplify(ll2_model)
163+
@test length(equations(sys2)) == 3
164+
u0 = states(sys) .=> 0
165+
prob = ODEProblem(sys, u0, (0, 10.0))
166+
@test_nowarn sol = solve(prob, FBDF())
167+
162168
@variables t x1(t) x2(t) x3(t) x4(t)
163169
D = Differential(t)
164170
@named sys1_inner = ODESystem([D(x1) ~ x1], t)

test/linearize.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,12 @@ if VERSION >= v"1.8"
207207

208208
@named model = ODESystem(eqs, t, [], []; systems = [link1, cart, force, fixed])
209209
def = ModelingToolkit.defaults(model)
210-
def[link1.y1] = 0
210+
for s in states(model)
211+
def[s] = 0
212+
end
211213
def[link1.x1] = 10
214+
def[link1.fy1] = -def[link1.g] * def[link1.m]
212215
def[link1.A] = -pi / 2
213-
def[link1.dA] = 0
214-
def[cart.s] = 0
215-
def[force.flange.v] = 0
216216
lin_outputs = [cart.s, cart.v, link1.A, link1.dA]
217217
lin_inputs = [force.f.u]
218218

@@ -237,7 +237,9 @@ if VERSION >= v"1.8"
237237
def = merge(def, Dict(x => 0.0 for x in dummyder))
238238

239239
@test substitute(lsyss.A, def) lsys.A
240-
@test substitute(lsyss.B, def) lsys.B
240+
# We cannot pivot symbolically, so the part where a linear solve is required
241+
# is not reliable.
242+
@test substitute(lsyss.B, def)[1:6, :] lsys.B[1:6, :]
241243
@test substitute(lsyss.C, def) lsys.C
242244
@test substitute(lsyss.D, def) lsys.D
243245
end

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ ss = structural_simplify(sys)
294294
@test isempty(equations(ss))
295295
@test sort(string.(observed(ss))) == ["x(t) ~ 0.0"
296296
"xˍt(t) ~ 0.0"
297-
"y(t) ~ xˍt(t)"]
297+
"y(t) ~ xˍt(t) - x(t)"]
298298

299299
eqs = [D(D(x)) ~ -x]
300300
@named sys = ODESystem(eqs, t, [x], [])

test/structural_transformation/tearing.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,6 @@ eqs = [
138138
0 ~ x + z,
139139
]
140140
@named nlsys = NonlinearSystem(eqs, [x, y, z], [])
141-
let (mm, _, _) = ModelingToolkit.aag_bareiss(nlsys)
142-
@test mm == [-1 1 0;
143-
0 -1 -1;
144-
0 0 0]
145-
end
146141

147142
newsys = tearing(nlsys)
148143
@test length(equations(newsys)) <= 1

0 commit comments

Comments
 (0)