Skip to content

Commit 891f58c

Browse files
committed
Fix walk_to_root
1 parent 63a9493 commit 891f58c

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

src/systems/alias_elimination.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,26 @@ end
2121

2222
function walk_to_root(ag, var_to_diff, v::Integer)
2323
diff_to_var = invview(var_to_diff)
24-
has_branch = true
25-
lowest_v = v
26-
while has_branch
27-
v′::Union{Nothing, Int} = v
28-
while (v′ = diff_to_var[v]) !== nothing
29-
v = v′
30-
end
31-
# `v` is now not differentiated in the current chain.
32-
# Now we visit the current chain.
33-
lowest_v = v
34-
while (v′ = var_to_diff[v]) !== nothing
35-
v = v′
36-
next_v = get(ag, v, nothing)
37-
next_v === nothing || (v = next_v; continue)
38-
end
39-
has_branch = false
24+
25+
v′::Union{Nothing, Int} = v
26+
@label HAS_BRANCH
27+
while (v′ = diff_to_var[v]) !== nothing
28+
v = v′
29+
end
30+
# `v` is now not differentiated in the current chain.
31+
# Now we recursively walk to root variable's chain.
32+
while true
33+
next_v = get(ag, v, nothing)
34+
next_v === nothing || (v = next_v[2]; @goto HAS_BRANCH)
35+
(v′ = var_to_diff[v]) === nothing && break
36+
v = v′
37+
end
38+
39+
# Descend to the root from the chain
40+
while (v′ = diff_to_var[v]) !== nothing
41+
v = v′
4042
end
41-
lowest_v
43+
v
4244
end
4345

4446
function visit_differential_aliases!(ag, level_to_var, processed, invag, var_to_diff, v, level=0)
@@ -168,16 +170,17 @@ function alias_elimination(sys)
168170
while (v = var_to_diff[v]) !== nothing
169171
if !(v in keys(ag))
170172
has_higher_order = true
173+
break
171174
end
172175
end
173176
if !has_higher_order
174177
rhs = fullvars[j]
175178
push!(eqs, subs[fullvars[j]] ~ rhs)
176-
push!(newstates, rhs)
179+
diff_to_var[j] === nothing && push!(newstates, rhs)
177180
end
178181
end
179182
else
180-
isdervar(state.structure, j) || push!(newstates, fullvars[j])
183+
diff_to_var[j] === nothing && push!(newstates, fullvars[j])
181184
end
182185
end
183186

0 commit comments

Comments
 (0)