1
1
using SymbolicUtils: Rewriters
2
2
3
- function fixpoint_sub (x, dict)
4
- y = substitute (x, dict)
5
- while ! isequal (x, y)
6
- y = x
7
- x = substitute (y, dict)
8
- end
9
-
10
- return x
11
- end
12
-
13
- function substitute_aliases (eqs, dict)
14
- sub = Base. Fix2 (fixpoint_sub, dict)
15
- map (eq-> eq. lhs ~ sub (eq. rhs), eqs)
16
- end
3
+ const KEEP = typemin (Int)
17
4
18
- # Note that we reduce parameters, too
19
- # i.e. `2param = 3` will be reduced away
20
- isvar (s) = s isa Sym ? true :
21
- istree (s) ? isvar (operation (s)) :
22
- false
23
-
24
- function get_α_x (αx)
25
- if isvar (αx)
26
- return 1 , αx
27
- elseif istree (αx) && operation (αx) === (* )
28
- args = arguments (αx)
29
- nums = []
30
- syms = []
31
- for arg in args
32
- isvar (arg) ? push! (syms, arg) : push! (nums, arg)
33
- end
34
-
35
- if length (syms) == 1
36
- return prod (nums), syms[1 ]
37
- end
38
- else
39
- return nothing
40
- end
41
- end
42
-
43
- function is_univariate_expr (ex, iv)
44
- count = 0
45
- for var in vars (ex)
46
- if ! isequal (iv, var) && ! isparameter (var)
47
- count += 1
48
- count > 1 && return false
49
- end
50
- end
51
- return count <= 1
52
- end
53
-
54
- function is_sub_candidate (ex, iv, conservative)
55
- conservative || return true
56
- isvar (ex) || ex isa Number || is_univariate_expr (ex, iv)
57
- end
58
-
59
- function maybe_alias (lhs, rhs, diff_vars, iv, conservative)
60
- is_sub_candidate (rhs, iv, conservative) || return false , nothing
61
-
62
- res_left = get_α_x (lhs)
63
- if res_left != = nothing && ! (res_left[2 ] in diff_vars)
64
- α, x = res_left
65
- sub = x => _isone (α) ? rhs : rhs / α
66
- return true , sub
67
- else
68
- return false , nothing
69
- end
70
- end
71
-
72
- function alias_elimination (sys)
5
+ function alias_eliminate_graph (sys)
73
6
sys = flatten (sys)
74
7
s = get_structure (sys)
75
8
if ! (s isa SystemStructure)
76
9
sys = initialize_system_structure (sys)
77
10
s = structure (sys)
78
11
end
79
- iv = independent_variable (sys)
80
- eqs = equations (sys)
81
- diff_vars = filter (! isnothing, map (eqs) do eq
82
- if isdiffeq (eq)
83
- arguments (eq. lhs)[1 ]
84
- else
85
- nothing
86
- end
87
- end ) |> Set
88
-
89
- deps = Set ()
90
- subs = Pair[]
91
- neweqs = Equation[]; sizehint! (neweqs, length (eqs))
92
12
93
- for (i, eq) in enumerate (eqs)
94
- # only substitute when the variable is algebraic
95
- if isdiffeq (eq)
96
- push! (neweqs, eq)
97
- continue
98
- end
99
-
100
- # `α x = rhs` => `x = rhs / α`
101
- ma, sub = maybe_alias (eq. lhs, eq. rhs, diff_vars, iv, conservative)
102
- if ! ma
103
- # `lhs = β y` => `y = lhs / β`
104
- ma, sub = maybe_alias (eq. rhs, eq. lhs, diff_vars, iv, conservative)
105
- end
106
-
107
- isalias = false
108
- if ma
109
- l, r = sub
110
- # alias equations shouldn't introduce cycles
111
- if ! (l in deps) && isempty (intersect (deps, vars (r)))
112
- push! (deps, l)
113
- push! (subs, sub)
114
- isalias = true
115
- end
116
- end
117
-
118
- if ! isalias
119
- neweq = _iszero (eq. lhs) ? eq : 0 ~ eq. rhs - eq. lhs
120
- push! (neweqs, neweq)
121
- end
122
- end
13
+ @unpack graph, varassoc = s
123
14
124
- alias_vars = first .(subs)
125
- sts = states (sys)
126
- fullsts = vcat (map (eq-> eq. lhs, observed (sys)), sts, parameters (sys))
127
- alias_eqs = topsort_equations (alias_vars .~ last .(subs), fullsts)
128
- newstates = setdiff (sts, alias_vars)
129
-
130
- @set! sys. eqs = substitute_aliases (neweqs, Dict (subs))
131
- @set! sys. states = newstates
132
- @set! sys. observed = [observed (sys); alias_eqs]
133
- return
134
- end
135
-
136
-
137
- function alias_elimination_2 (sys)
138
- sys = flatten (sys)
139
- s = get_structure (sys)
140
- if ! (s isa SystemStructure)
141
- sys = initialize_system_structure (sys)
142
- s = structure (sys)
143
- end
144
- find_solvables! (sys)
145
- @unpack graph, solvable_graph, is_linear_equations, varassoc = s
15
+ is_linear_equations, eadj, cadj = find_linear_equations (sys)
16
+ old_cadj = map (copy, cadj)
146
17
147
18
is_not_potential_state = iszero .(varassoc)
148
19
is_linear_variables = copy (is_not_potential_state)
@@ -155,29 +26,74 @@ function alias_elimination_2(sys)
155
26
156
27
linear_equations = findall (is_linear_equations)
157
28
158
- offset = 1
159
- coeffs = solvable_graph. metadata
160
- old_coeffs = map (copy, coeffs)
161
- fadj = solvable_graph. fadjlist
162
29
163
30
rank1 = bareiss! (
164
- (fadj, coeffs ),
165
- old_coeffs , linear_equations, is_linear_variables, offset
31
+ (eadg, cadj ),
32
+ old_cadj , linear_equations, is_linear_variables, 1
166
33
)
167
34
168
- v_solved = [fadj [i][1 ] for i in 1 : rank1]
169
- v_null = setdiff (solvable_variables, v_solved)
170
- n_null_vars = length (v_null )
35
+ v_solved = [eadg [i][1 ] for i in 1 : rank1]
36
+ v_eliminated = setdiff (solvable_variables, v_solved)
37
+ n_null_vars = length (v_eliminated )
171
38
172
39
v_types = fill (KEEP, ndsts (graph))
173
- for v in v_null
40
+ for v in v_eliminated
174
41
v_types[v] = 0
175
42
end
176
43
177
44
rank2 = bareiss! (
178
- (fadj, coeffs ),
179
- old_coeffs , linear_equations, is_not_potential_state, offset
45
+ (eadg, cadj ),
46
+ old_cadj , linear_equations, is_not_potential_state, rank1 + 1
180
47
)
48
+
49
+ rank3 = bareiss! (
50
+ (eadg, cadj),
51
+ old_cadj, linear_equations, nothing , rank2+ 1
52
+ )
53
+
54
+ # kind of like the backward substitution
55
+ for ei in reverse (1 : rank2)
56
+ locally_structure_simplify! (
57
+ (eadg[ei], cadj[ei]),
58
+ invvarassoc, v_eliminated, v_types
59
+ )
60
+ end
61
+
62
+ reduced = false
63
+ for ei in 1 : rank2
64
+ if length (cadj[ei]) > length (old_cadj[ei])
65
+ cadj[ei] = old_cadj[ei]
66
+ else
67
+ cadj[ei] = eadg[linear_equations[ei]]
68
+ reduced |= locally_structure_simplify! (
69
+ (eadg[ei], cadj[ei]),
70
+ invvarassoc, v_eliminated, v_types
71
+ )
72
+ end
73
+ end
74
+
75
+ while reduced
76
+ for ei in 1 : rank2
77
+ if ! isempty (eadg[ei])
78
+ reduced |= locally_structure_simplify! (
79
+ (eadg[ei], cadj[ei]),
80
+ invvarassoc, v_eliminated, v_types
81
+ )
82
+ reduced && break # go back to the begining of equations
83
+ end
84
+ end
85
+ end
86
+
87
+ for ei in rank2+ 1 : length (linear_equations)
88
+ eadg[ei] = old_cadj[ei]
89
+ end
90
+
91
+ for (ei, e) in enumerate (linear_equations)
92
+ graph. eadglist[e] = eadg[ei]
93
+ end
94
+
95
+ degenerate_equations = rank3 < length (linear_equations) ? linear_equations[rank3+ 1 : end ] : Int[]
96
+ return v_eliminated, v_types, n_null_vars, degenerate_equations
181
97
end
182
98
183
99
iszeroterm (v_types, v) = v_types[v] == 0
@@ -188,7 +104,7 @@ negalias(v_types, v) = -v_types[v]
188
104
189
105
function locally_structure_simplify! (
190
106
(vars, coeffs),
191
- invvarassoc, v_null , v_types
107
+ invvarassoc, v_eliminated , v_types
192
108
)
193
109
while length (vars) > 1 && any (! isequal (KEEP), (v_types[v] in @view vars[2 : end ]))
194
110
for vj in 2 : length (vars)
@@ -238,18 +154,18 @@ function locally_structure_simplify!(
238
154
v = first (vars)
239
155
if invvarassoc[v] == 0
240
156
if length (nvars) == 1
241
- push! (v_null , v)
157
+ push! (v_eliminated , v)
242
158
v_types[v] = 0
243
159
empty! (vars); empty! (coeffs)
244
160
return true
245
161
elseif length (vars) == 2 && abs (coeffs[1 ]) == abs (coeffs[2 ])
246
162
if (coeffs[1 ] > 0 && coeffs[2 ] < 0 ) || (coeffs[1 ] < 0 && coeffs[2 ] > 0 )
247
163
# positive alias
248
- push! (v_null , v)
164
+ push! (v_eliminated , v)
249
165
v_types[v] = vars[2 ]
250
166
else
251
167
# negative alias
252
- push! (v_null , v)
168
+ push! (v_eliminated , v)
253
169
v_types[v] = - vars[2 ]
254
170
end
255
171
empty! (vars); empty! (coeffs)
@@ -265,11 +181,11 @@ $(SIGNATURES)
265
181
Use Bareiss algorithm to compute the nullspace of an integer matrix exactly.
266
182
"""
267
183
function bareiss! (
268
- (fadj, coeffs ),
269
- old_coeffs , linear_equations, is_linear_variables, offset
184
+ (eadg, cadj ),
185
+ old_cadj , linear_equations, is_linear_variables, offset
270
186
)
271
187
m = nsrcs (solvable_graph)
272
- # v = fadj [ei][vj]
188
+ # v = eadg [ei][vj]
273
189
v = ei = vj = 0
274
190
pivot = last_pivot = 1
275
191
tmp_incidence = Int[]
@@ -293,14 +209,14 @@ function bareiss!(
293
209
end
294
210
295
211
if vj > 0 # has a pivot
296
- pivot = coeffs [ei][vj]
297
- deleteat! (coeffs [ei] , vj)
298
- v = fadj [ei][vj]
299
- deleteat! (fadj [ei], vj)
212
+ pivot = cadj [ei][vj]
213
+ deleteat! (cadj [ei] , vj)
214
+ v = eadg [ei][vj]
215
+ deleteat! (eadg [ei], vj)
300
216
if ei != k
301
- swap! (coeffs , ei, k)
302
- swap! (old_coeffs , ei, k)
303
- swap! (fadj , ei, k)
217
+ swap! (cadj , ei, k)
218
+ swap! (old_cadj , ei, k)
219
+ swap! (eadg , ei, k)
304
220
swap! (linear_equations, ei, k)
305
221
end
306
222
else # rank deficient
@@ -310,22 +226,22 @@ function bareiss!(
310
226
for ei in k+ 1
311
227
# elimate `v`
312
228
coeff = 0
313
- vars = fadj [ei]
229
+ vars = eadg [ei]
314
230
vj = findfirst (isequal (v), vars)
315
231
if vj === nothing # `v` is not in in `e`
316
232
continue
317
233
else # remove `v`
318
- coeff = coeffs [ei][vj]
319
- deleteat! (coeffs [ei], vj)
320
- deleteat! (fadj [ei], vj)
234
+ coeff = cadj [ei][vj]
235
+ deleteat! (cadj [ei], vj)
236
+ deleteat! (eadg [ei], vj)
321
237
end
322
238
323
239
# the pivot row
324
- kvars = fadj [k]
325
- kcoeffs = coeffs [k]
240
+ kvars = eadg [k]
241
+ kcoeffs = cadj [k]
326
242
# the elimination target
327
- ivars = fadj [ei]
328
- icoeffs = coeffs [ei]
243
+ ivars = eadg [ei]
244
+ icoeffs = cadj [ei]
329
245
330
246
empty! (tmp_incidence)
331
247
empty! (tmp_coeffs)
@@ -342,13 +258,13 @@ function bareiss!(
342
258
end
343
259
end
344
260
345
- fadj [ei], tmp_incidence = tmp_incidence, fadj [ei]
346
- coeffs [ei], tmp_coeffs = tmp_coeffs, coeffs [ei]
261
+ eadg [ei], tmp_incidence = tmp_incidence, eadg [ei]
262
+ cadj [ei], tmp_coeffs = tmp_coeffs, cadj [ei]
347
263
end
348
264
last_pivot = pivot
349
265
# add `v` in the front of the `k`-th equation
350
- pushfirst! (fadj [k], v)
351
- pushfirst! (coeffs [k], pivot)
266
+ pushfirst! (eadg [k], v)
267
+ pushfirst! (cadj [k], pivot)
352
268
end
353
269
354
270
return m # fully ranked
@@ -372,14 +288,14 @@ the `constraint`.
372
288
@inline function find_first_linear_variable (
373
289
solvable_graph,
374
290
range,
375
- is_linear_variables ,
291
+ mask ,
376
292
constraint,
377
293
)
378
294
for i in range
379
295
vertices = 𝑠vertices (solvable_graph, i)
380
296
if constraint (length (vertices))
381
297
for (j, v) in enumerate (vertices)
382
- is_linear_variables [v] && return i, j
298
+ (mask === nothing || mask [v]) && return i, j
383
299
end
384
300
end
385
301
end
@@ -464,3 +380,18 @@ function observed2graph(eqs, states)
464
380
465
381
return graph, assigns
466
382
end
383
+
384
+ function fixpoint_sub (x, dict)
385
+ y = substitute (x, dict)
386
+ while ! isequal (x, y)
387
+ y = x
388
+ x = substitute (y, dict)
389
+ end
390
+
391
+ return x
392
+ end
393
+
394
+ function substitute_aliases (eqs, dict)
395
+ sub = Base. Fix2 (fixpoint_sub, dict)
396
+ map (eq-> eq. lhs ~ sub (eq. rhs), eqs)
397
+ end
0 commit comments