|
21 | 21 |
|
22 | 22 | function walk_to_root(ag, var_to_diff, v::Integer)
|
23 | 23 | diff_to_var = invview(var_to_diff)
|
24 |
| - has_branch = true |
25 |
| - lowest_v = v |
26 |
| - while has_branch |
27 |
| - v′::Union{Nothing, Int} = v |
28 |
| - while (v′ = diff_to_var[v]) !== nothing |
29 |
| - v = v′ |
30 |
| - end |
31 |
| - # `v` is now not differentiated in the current chain. |
32 |
| - # Now we visit the current chain. |
33 |
| - lowest_v = v |
34 |
| - while (v′ = var_to_diff[v]) !== nothing |
35 |
| - v = v′ |
36 |
| - next_v = get(ag, v, nothing) |
37 |
| - next_v === nothing || (v = next_v; continue) |
38 |
| - end |
39 |
| - has_branch = false |
| 24 | + |
| 25 | + v′::Union{Nothing, Int} = v |
| 26 | + @label HAS_BRANCH |
| 27 | + while (v′ = diff_to_var[v]) !== nothing |
| 28 | + v = v′ |
| 29 | + end |
| 30 | + # `v` is now not differentiated in the current chain. |
| 31 | + # Now we recursively walk to root variable's chain. |
| 32 | + while true |
| 33 | + next_v = get(ag, v, nothing) |
| 34 | + next_v === nothing || (v = next_v[2]; @goto HAS_BRANCH) |
| 35 | + (v′ = var_to_diff[v]) === nothing && break |
| 36 | + v = v′ |
| 37 | + end |
| 38 | + |
| 39 | + # Descend to the root from the chain |
| 40 | + while (v′ = diff_to_var[v]) !== nothing |
| 41 | + v = v′ |
40 | 42 | end
|
41 |
| - lowest_v |
| 43 | + v |
42 | 44 | end
|
43 | 45 |
|
44 | 46 | function visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, v, level=0)
|
@@ -168,16 +170,17 @@ function alias_elimination(sys)
|
168 | 170 | while (v = var_to_diff[v]) !== nothing
|
169 | 171 | if !(v in keys(ag))
|
170 | 172 | has_higher_order = true
|
| 173 | + break |
171 | 174 | end
|
172 | 175 | end
|
173 | 176 | if !has_higher_order
|
174 | 177 | rhs = fullvars[j]
|
175 | 178 | push!(eqs, subs[fullvars[j]] ~ rhs)
|
176 |
| - push!(newstates, rhs) |
| 179 | + diff_to_var[j] === nothing && push!(newstates, rhs) |
177 | 180 | end
|
178 | 181 | end
|
179 | 182 | else
|
180 |
| - isdervar(state.structure, j) || push!(newstates, fullvars[j]) |
| 183 | + diff_to_var[j] === nothing && push!(newstates, fullvars[j]) |
181 | 184 | end
|
182 | 185 | end
|
183 | 186 |
|
|
0 commit comments