Skip to content

Commit 1a07aba

Browse files
committed
Don't assume direction of ag so that we have acyclic guarantee
1 parent cdf30dc commit 1a07aba

File tree

2 files changed

+53
-43
lines changed

2 files changed

+53
-43
lines changed

src/systems/alias_elimination.jl

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,56 +19,53 @@ function aag_bareiss(sys::AbstractSystem)
1919
return aag_bareiss!(state.structure.graph, complete(state.structure.var_to_diff), mm)
2020
end
2121

22-
function walk_to_root(ag, var_to_diff, v::Integer)
23-
diff_to_var = invview(var_to_diff)
22+
function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)) where descend
23+
g = descend ? invview(var_to_diff) : var_to_diff
24+
while (v′ = g[v]) !== nothing
25+
v = v′
26+
if level !== nothing
27+
descend ? (level -= 1) : (level += 1)
28+
end
29+
end
30+
level === nothing ? v : (v => level)
31+
end
2432

33+
function neighbor_branches!(visited, (ag, invag), var_to_diff, v, level = 0)
34+
ns = Pair{Int, Int}[]
35+
visited[v] && return ns
2536
v′::Union{Nothing, Int} = v
26-
@label HAS_BRANCH
37+
diff_to_var = invview(var_to_diff)
2738
while (v′ = diff_to_var[v]) !== nothing
2839
v = v′
40+
level -= 1
2941
end
30-
# `v` is now not differentiated in the current chain.
31-
# Now we recursively walk to root variable's chain.
3242
while true
33-
next_v = get(ag, v, nothing)
34-
next_v === nothing || (v = next_v[2]; @goto HAS_BRANCH)
43+
if (_n = get(ag, v, nothing)) !== nothing
44+
n = _n[2]
45+
visited[n] || push!(ns, n => level)
46+
end
47+
for n in neighbors(invag, v)
48+
visited[n] || push!(ns, n => level)
49+
end
50+
visited[v] = true
3551
(v′ = var_to_diff[v]) === nothing && break
3652
v = v′
53+
level += 1
3754
end
38-
39-
# Descend to the root from the chain
40-
while (v′ = diff_to_var[v]) !== nothing
41-
v = v′
42-
end
43-
v
55+
ns
4456
end
4557

46-
function visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, v, level=0)
47-
processed[v] && return nothing
48-
for n in neighbors(invag, v)
49-
# TODO: we currently only handle `coeff == 1`
50-
if isone(ag[n][1])
51-
visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, n, level)
52-
end
53-
end
54-
# Note that we don't need to update `invag`
55-
if 1 <= level + 1 <= length(level_to_var)
56-
root_var = level_to_var[level + 1]
57-
if v != root_var
58-
ag[v] = 1 => root_var
58+
function walk_to_root!(visited, ags, var_to_diff, v::Integer, level = 0)
59+
brs = neighbor_branches!(visited, ags, var_to_diff, v, level)
60+
min_var_level = v => level
61+
isempty(brs) && return extreme_var(var_to_diff, min_var_level...)
62+
for (x, lv) in brs
63+
x, lv = walk_to_root!(visited, ags, var_to_diff, x, lv)
64+
if min_var_level[2] > lv
65+
min_var_level = x => lv
5966
end
60-
else
61-
@assert length(level_to_var) == level
62-
push!(level_to_var, v)
6367
end
64-
processed[v] = true
65-
if (dv = var_to_diff[v]) !== nothing
66-
visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, dv, level + 1)
67-
end
68-
if (iv = invview(var_to_diff)[v]) !== nothing
69-
visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, iv, level - 1)
70-
end
71-
return nothing
68+
return extreme_var(var_to_diff, min_var_level...)
7269
end
7370

7471
function alias_elimination(sys)
@@ -86,7 +83,7 @@ function alias_elimination(sys)
8683
# ⇓ ⇑
8784
# ⇓ x_t --> D(x_t)
8885
# ⇓ |---------------|
89-
# z --> D(z) --> D(D(z)) |--> D(D(D(z))) |
86+
# z --> D(z) --> D(D(z)) |--> D(D(D(z))) |
9087
# ⇑ |---------------|
9188
# k --> D(k)
9289
#
@@ -102,26 +99,38 @@ function alias_elimination(sys)
10299
# with a tie breaking strategy. The root variable (in this case `z`) is
103100
# always uniquely determined. Thus, the result is well-defined.
104101
D = has_iv(sys) ? Differential(get_iv(sys)) : nothing
102+
nvars = length(fullvars)
105103
diff_to_var = invview(var_to_diff)
106-
invag = SimpleDiGraph(length(fullvars))
104+
invag = SimpleDiGraph(nvars)
107105
for (v, (coeff, alias)) in pairs(ag)
108106
iszero(coeff) && continue
109107
add_edge!(invag, alias, v)
110108
end
111-
processed = falses(length(var_to_diff))
109+
Main._a[] = ag, invag
110+
processed = falses(nvars)
111+
visited = falses(nvars)
112+
newag = AliasGraph(nvars)
112113
for (v, dv) in enumerate(var_to_diff)
113114
processed[v] && continue
114115
(dv === nothing && diff_to_var[v] === nothing) && continue
115116

116-
r = walk_to_root(ag, var_to_diff, v)
117+
# TODO: use an iterator, and get a relative level vector for `processed`
118+
# variabels.
119+
r, lv = walk_to_root!(processed, (ag, invag), var_to_diff, v)
120+
#lv = extreme_var(var_to_diff, v, -lv, Val(false))
121+
lv′ = extreme_var(var_to_diff, v, 0, Val(false))[2]
122+
let
123+
sv = fullvars[v]
124+
root = fullvars[r]
125+
@warn "" sv => root level = lv levelv = lv′
126+
end
117127
level_to_var = Int[r]
118128
v′′::Union{Nothing, Int} = v′::Int = r
119129
while (v′′ = var_to_diff[v′]) !== nothing
120130
v′ = v′′
121131
push!(level_to_var, v′)
122132
end
123133
nlevels = length(level_to_var)
124-
visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, r)
125134
if nlevels < (new_nlevels = length(level_to_var))
126135
@assert !(D isa Nothing)
127136
for i in (nlevels + 1):new_nlevels
@@ -502,6 +511,7 @@ function locally_structure_simplify!(adj_row, pivot_var, ag, var_to_diff)
502511
if alias_candidate isa Pair
503512
alias_val, alias_var = alias_candidate
504513
#preferred_var = pivot_var
514+
#=
505515
switch = false # we prefer `alias_var` by default, unless we switch
506516
diff_to_var = invview(var_to_diff)
507517
pivot_var′′::Union{Nothing, Int} = pivot_var′::Int = pivot_var
@@ -525,6 +535,7 @@ function locally_structure_simplify!(adj_row, pivot_var, ag, var_to_diff)
525535
pivot_var, alias_var = alias_var, pivot_var
526536
pivot_val, alias_val = alias_val, pivot_val
527537
end
538+
=#
528539

529540
# `p` is the pivot variable, `a` is the alias variable, `v` and `c` are
530541
# their coefficients.

src/systems/systemstructure.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,6 @@ function linear_subsys_adjmat(state::TransformationState)
348348
cadj = Vector{Int}[]
349349
coeffs = Int[]
350350
for (i, eq) in enumerate(eqs)
351-
isdiffeq(eq) && continue
352351
empty!(coeffs)
353352
linear_term = 0
354353
all_int_vars = true

0 commit comments

Comments
 (0)