Skip to content

Commit 6dcec3f

Browse files
committed
Add alias_eliminate_graph
1 parent 12e3280 commit 6dcec3f

File tree

1 file changed

+105
-174
lines changed

1 file changed

+105
-174
lines changed

src/systems/alias_elimination.jl

Lines changed: 105 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -1,148 +1,19 @@
11
using SymbolicUtils: Rewriters
22

3-
function fixpoint_sub(x, dict)
4-
y = substitute(x, dict)
5-
while !isequal(x, y)
6-
y = x
7-
x = substitute(y, dict)
8-
end
9-
10-
return x
11-
end
12-
13-
function substitute_aliases(eqs, dict)
14-
sub = Base.Fix2(fixpoint_sub, dict)
15-
map(eq->eq.lhs ~ sub(eq.rhs), eqs)
16-
end
3+
const KEEP = typemin(Int)
174

18-
# Note that we reduce parameters, too
19-
# i.e. `2param = 3` will be reduced away
20-
isvar(s) = s isa Sym ? true :
21-
istree(s) ? isvar(operation(s)) :
22-
false
23-
24-
function get_α_x(αx)
25-
if isvar(αx)
26-
return 1, αx
27-
elseif istree(αx) && operation(αx) === (*)
28-
args = arguments(αx)
29-
nums = []
30-
syms = []
31-
for arg in args
32-
isvar(arg) ? push!(syms, arg) : push!(nums, arg)
33-
end
34-
35-
if length(syms) == 1
36-
return prod(nums), syms[1]
37-
end
38-
else
39-
return nothing
40-
end
41-
end
42-
43-
function is_univariate_expr(ex, iv)
44-
count = 0
45-
for var in vars(ex)
46-
if !isequal(iv, var) && !isparameter(var)
47-
count += 1
48-
count > 1 && return false
49-
end
50-
end
51-
return count <= 1
52-
end
53-
54-
function is_sub_candidate(ex, iv, conservative)
55-
conservative || return true
56-
isvar(ex) || ex isa Number || is_univariate_expr(ex, iv)
57-
end
58-
59-
function maybe_alias(lhs, rhs, diff_vars, iv, conservative)
60-
is_sub_candidate(rhs, iv, conservative) || return false, nothing
61-
62-
res_left = get_α_x(lhs)
63-
if res_left !== nothing && !(res_left[2] in diff_vars)
64-
α, x = res_left
65-
sub = x => _isone(α) ? rhs : rhs / α
66-
return true, sub
67-
else
68-
return false, nothing
69-
end
70-
end
71-
72-
function alias_elimination(sys)
5+
function alias_eliminate_graph(sys)
736
sys = flatten(sys)
747
s = get_structure(sys)
758
if !(s isa SystemStructure)
769
sys = initialize_system_structure(sys)
7710
s = structure(sys)
7811
end
79-
iv = independent_variable(sys)
80-
eqs = equations(sys)
81-
diff_vars = filter(!isnothing, map(eqs) do eq
82-
if isdiffeq(eq)
83-
arguments(eq.lhs)[1]
84-
else
85-
nothing
86-
end
87-
end) |> Set
88-
89-
deps = Set()
90-
subs = Pair[]
91-
neweqs = Equation[]; sizehint!(neweqs, length(eqs))
9212

93-
for (i, eq) in enumerate(eqs)
94-
# only substitute when the variable is algebraic
95-
if isdiffeq(eq)
96-
push!(neweqs, eq)
97-
continue
98-
end
99-
100-
# `α x = rhs` => `x = rhs / α`
101-
ma, sub = maybe_alias(eq.lhs, eq.rhs, diff_vars, iv, conservative)
102-
if !ma
103-
# `lhs = β y` => `y = lhs / β`
104-
ma, sub = maybe_alias(eq.rhs, eq.lhs, diff_vars, iv, conservative)
105-
end
106-
107-
isalias = false
108-
if ma
109-
l, r = sub
110-
# alias equations shouldn't introduce cycles
111-
if !(l in deps) && isempty(intersect(deps, vars(r)))
112-
push!(deps, l)
113-
push!(subs, sub)
114-
isalias = true
115-
end
116-
end
117-
118-
if !isalias
119-
neweq = _iszero(eq.lhs) ? eq : 0 ~ eq.rhs - eq.lhs
120-
push!(neweqs, neweq)
121-
end
122-
end
13+
@unpack graph, varassoc = s
12314

124-
alias_vars = first.(subs)
125-
sts = states(sys)
126-
fullsts = vcat(map(eq->eq.lhs, observed(sys)), sts, parameters(sys))
127-
alias_eqs = topsort_equations(alias_vars .~ last.(subs), fullsts)
128-
newstates = setdiff(sts, alias_vars)
129-
130-
@set! sys.eqs = substitute_aliases(neweqs, Dict(subs))
131-
@set! sys.states = newstates
132-
@set! sys.observed = [observed(sys); alias_eqs]
133-
return
134-
end
135-
136-
137-
function alias_elimination_2(sys)
138-
sys = flatten(sys)
139-
s = get_structure(sys)
140-
if !(s isa SystemStructure)
141-
sys = initialize_system_structure(sys)
142-
s = structure(sys)
143-
end
144-
find_solvables!(sys)
145-
@unpack graph, solvable_graph, is_linear_equations, varassoc = s
15+
is_linear_equations, eadj, cadj = find_linear_equations(sys)
16+
old_cadj = map(copy, cadj)
14617

14718
is_not_potential_state = iszero.(varassoc)
14819
is_linear_variables = copy(is_not_potential_state)
@@ -155,29 +26,74 @@ function alias_elimination_2(sys)
15526

15627
linear_equations = findall(is_linear_equations)
15728

158-
offset = 1
159-
coeffs = solvable_graph.metadata
160-
old_coeffs = map(copy, coeffs)
161-
fadj = solvable_graph.fadjlist
16229

16330
rank1 = bareiss!(
164-
(fadj, coeffs),
165-
old_coeffs, linear_equations, is_linear_variables, offset
31+
(eadg, cadj),
32+
old_cadj, linear_equations, is_linear_variables, 1
16633
)
16734

168-
v_solved = [fadj[i][1] for i in 1:rank1]
169-
v_null = setdiff(solvable_variables, v_solved)
170-
n_null_vars = length(v_null)
35+
v_solved = [eadg[i][1] for i in 1:rank1]
36+
v_eliminated = setdiff(solvable_variables, v_solved)
37+
n_null_vars = length(v_eliminated)
17138

17239
v_types = fill(KEEP, ndsts(graph))
173-
for v in v_null
40+
for v in v_eliminated
17441
v_types[v] = 0
17542
end
17643

17744
rank2 = bareiss!(
178-
(fadj, coeffs),
179-
old_coeffs, linear_equations, is_not_potential_state, offset
45+
(eadg, cadj),
46+
old_cadj, linear_equations, is_not_potential_state, rank1+1
18047
)
48+
49+
rank3 = bareiss!(
50+
(eadg, cadj),
51+
old_cadj, linear_equations, nothing, rank2+1
52+
)
53+
54+
# kind of like the backward substitution
55+
for ei in reverse(1:rank2)
56+
locally_structure_simplify!(
57+
(eadg[ei], cadj[ei]),
58+
invvarassoc, v_eliminated, v_types
59+
)
60+
end
61+
62+
reduced = false
63+
for ei in 1:rank2
64+
if length(cadj[ei]) > length(old_cadj[ei])
65+
cadj[ei] = old_cadj[ei]
66+
else
67+
cadj[ei] = eadg[linear_equations[ei]]
68+
reduced |= locally_structure_simplify!(
69+
(eadg[ei], cadj[ei]),
70+
invvarassoc, v_eliminated, v_types
71+
)
72+
end
73+
end
74+
75+
while reduced
76+
for ei in 1:rank2
77+
if !isempty(eadg[ei])
78+
reduced |= locally_structure_simplify!(
79+
(eadg[ei], cadj[ei]),
80+
invvarassoc, v_eliminated, v_types
81+
)
82+
reduced && break # go back to the begining of equations
83+
end
84+
end
85+
end
86+
87+
for ei in rank2+1:length(linear_equations)
88+
eadg[ei] = old_cadj[ei]
89+
end
90+
91+
for (ei, e) in enumerate(linear_equations)
92+
graph.eadglist[e] = eadg[ei]
93+
end
94+
95+
degenerate_equations = rank3 < length(linear_equations) ? linear_equations[rank3+1:end] : Int[]
96+
return v_eliminated, v_types, n_null_vars, degenerate_equations
18197
end
18298

18399
iszeroterm(v_types, v) = v_types[v] == 0
@@ -188,7 +104,7 @@ negalias(v_types, v) = -v_types[v]
188104

189105
function locally_structure_simplify!(
190106
(vars, coeffs),
191-
invvarassoc, v_null, v_types
107+
invvarassoc, v_eliminated, v_types
192108
)
193109
while length(vars) > 1 && any(!isequal(KEEP), (v_types[v] in @view vars[2:end]))
194110
for vj in 2:length(vars)
@@ -238,18 +154,18 @@ function locally_structure_simplify!(
238154
v = first(vars)
239155
if invvarassoc[v] == 0
240156
if length(nvars) == 1
241-
push!(v_null, v)
157+
push!(v_eliminated, v)
242158
v_types[v] = 0
243159
empty!(vars); empty!(coeffs)
244160
return true
245161
elseif length(vars) == 2 && abs(coeffs[1]) == abs(coeffs[2])
246162
if (coeffs[1] > 0 && coeffs[2] < 0) || (coeffs[1] < 0 && coeffs[2] > 0)
247163
# positive alias
248-
push!(v_null, v)
164+
push!(v_eliminated, v)
249165
v_types[v] = vars[2]
250166
else
251167
# negative alias
252-
push!(v_null, v)
168+
push!(v_eliminated, v)
253169
v_types[v] = -vars[2]
254170
end
255171
empty!(vars); empty!(coeffs)
@@ -265,11 +181,11 @@ $(SIGNATURES)
265181
Use Bareiss algorithm to compute the nullspace of an integer matrix exactly.
266182
"""
267183
function bareiss!(
268-
(fadj, coeffs),
269-
old_coeffs, linear_equations, is_linear_variables, offset
184+
(eadg, cadj),
185+
old_cadj, linear_equations, is_linear_variables, offset
270186
)
271187
m = nsrcs(solvable_graph)
272-
# v = fadj[ei][vj]
188+
# v = eadg[ei][vj]
273189
v = ei = vj = 0
274190
pivot = last_pivot = 1
275191
tmp_incidence = Int[]
@@ -293,14 +209,14 @@ function bareiss!(
293209
end
294210

295211
if vj > 0 # has a pivot
296-
pivot = coeffs[ei][vj]
297-
deleteat!(coeffs[ei] , vj)
298-
v = fadj[ei][vj]
299-
deleteat!(fadj[ei], vj)
212+
pivot = cadj[ei][vj]
213+
deleteat!(cadj[ei] , vj)
214+
v = eadg[ei][vj]
215+
deleteat!(eadg[ei], vj)
300216
if ei != k
301-
swap!(coeffs, ei, k)
302-
swap!(old_coeffs, ei, k)
303-
swap!(fadj, ei, k)
217+
swap!(cadj, ei, k)
218+
swap!(old_cadj, ei, k)
219+
swap!(eadg, ei, k)
304220
swap!(linear_equations, ei, k)
305221
end
306222
else # rank deficient
@@ -310,22 +226,22 @@ function bareiss!(
310226
for ei in k+1
311227
# elimate `v`
312228
coeff = 0
313-
vars = fadj[ei]
229+
vars = eadg[ei]
314230
vj = findfirst(isequal(v), vars)
315231
if vj === nothing # `v` is not in in `e`
316232
continue
317233
else # remove `v`
318-
coeff = coeffs[ei][vj]
319-
deleteat!(coeffs[ei], vj)
320-
deleteat!(fadj[ei], vj)
234+
coeff = cadj[ei][vj]
235+
deleteat!(cadj[ei], vj)
236+
deleteat!(eadg[ei], vj)
321237
end
322238

323239
# the pivot row
324-
kvars = fadj[k]
325-
kcoeffs = coeffs[k]
240+
kvars = eadg[k]
241+
kcoeffs = cadj[k]
326242
# the elimination target
327-
ivars = fadj[ei]
328-
icoeffs = coeffs[ei]
243+
ivars = eadg[ei]
244+
icoeffs = cadj[ei]
329245

330246
empty!(tmp_incidence)
331247
empty!(tmp_coeffs)
@@ -342,13 +258,13 @@ function bareiss!(
342258
end
343259
end
344260

345-
fadj[ei], tmp_incidence = tmp_incidence, fadj[ei]
346-
coeffs[ei], tmp_coeffs = tmp_coeffs, coeffs[ei]
261+
eadg[ei], tmp_incidence = tmp_incidence, eadg[ei]
262+
cadj[ei], tmp_coeffs = tmp_coeffs, cadj[ei]
347263
end
348264
last_pivot = pivot
349265
# add `v` in the front of the `k`-th equation
350-
pushfirst!(fadj[k], v)
351-
pushfirst!(coeffs[k], pivot)
266+
pushfirst!(eadg[k], v)
267+
pushfirst!(cadj[k], pivot)
352268
end
353269

354270
return m # fully ranked
@@ -372,14 +288,14 @@ the `constraint`.
372288
@inline function find_first_linear_variable(
373289
solvable_graph,
374290
range,
375-
is_linear_variables,
291+
mask,
376292
constraint,
377293
)
378294
for i in range
379295
vertices = 𝑠vertices(solvable_graph, i)
380296
if constraint(length(vertices))
381297
for (j, v) in enumerate(vertices)
382-
is_linear_variables[v] && return i, j
298+
(mask === nothing || mask[v]) && return i, j
383299
end
384300
end
385301
end
@@ -464,3 +380,18 @@ function observed2graph(eqs, states)
464380

465381
return graph, assigns
466382
end
383+
384+
function fixpoint_sub(x, dict)
385+
y = substitute(x, dict)
386+
while !isequal(x, y)
387+
y = x
388+
x = substitute(y, dict)
389+
end
390+
391+
return x
392+
end
393+
394+
function substitute_aliases(eqs, dict)
395+
sub = Base.Fix2(fixpoint_sub, dict)
396+
map(eq->eq.lhs ~ sub(eq.rhs), eqs)
397+
end

0 commit comments

Comments
 (0)