Skip to content

Commit b464d72

Browse files
committed
WIP
1 parent 37c6ca8 commit b464d72

File tree

2 files changed

+76
-9
lines changed

2 files changed

+76
-9
lines changed

src/systems/alias_elimination.jl

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,22 @@ end
288288

289289
function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
290290
g = SimpleDiGraph{Int}(length(var_to_diff))
291+
zero_vars = Int[]
291292
for (v, (_, a)) in ag
292-
iszero(a) && continue
293+
if iszero(a)
294+
push!(zero_vars, v)
295+
continue
296+
end
293297
add_edge!(g, v, a)
294298
add_edge!(g, a, v)
295299
end
296300
transitiveclosure!(g)
301+
zero_vars_set = BitSet(zero_vars)
302+
for v in zero_vars
303+
for a in outneighbors(g, v)
304+
push!(zero_vars_set, a)
305+
end
306+
end
297307
# Compute the largest transitive closure that doesn't include any diff
298308
# edges.
299309
og = g
@@ -319,7 +329,7 @@ function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
319329
add_edge!(g, v, dv)
320330
add_edge!(g, dv, v)
321331
end
322-
g, edge_styles
332+
g, zero_vars_set, edge_styles
323333
end
324334

325335
using Graphs.Experimental.Traversals
@@ -780,18 +790,16 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
780790
# Note that since we always prefer the higher differentiated variable and
781791
# with a tie breaking strategy, the root variable (in this case `z`) is
782792
# always uniquely determined. Thus, the result is well-defined.
793+
dag = AliasGraph(nvars) # alias graph for differentiated variables
794+
updated_diff_vars = Int[]
783795
diff_to_var = invview(var_to_diff)
784-
invag = SimpleDiGraph(nvars)
785-
for (v, (coeff, alias)) in pairs(ag)
786-
iszero(coeff) && continue
787-
add_edge!(invag, alias, v)
788-
end
789796
processed = falses(nvars)
790-
g, = tograph(ag, var_to_diff)
797+
g, zero_vars_set = tograph(ag, var_to_diff)
791798
dls = DiffLevelState(g, var_to_diff)
792799
is_diff_edge = let var_to_diff = var_to_diff
793800
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
794801
end
802+
diff_aliases = Vector{Pair{Int, Int}}[]
795803
for (v, dv) in enumerate(var_to_diff)
796804
processed[v] && continue
797805
(dv === nothing && diff_to_var[v] === nothing) && continue
@@ -801,15 +809,67 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
801809
extreme_var(var_to_diff, r, nothing, Val(false),
802810
callback = Base.Fix1(push!, level_to_var))
803811
nlevels = length(level_to_var)
804-
current_coeff_level = Ref((0, 0))
812+
prev_r = -1
813+
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
814+
reach₌ = Pair{Int, Int}[]
815+
r === nothing || for n in neighbors(g, r)
816+
(n == r || is_diff_edge(r, n)) && continue
817+
c = 1
818+
push!(reach₌, c => n)
819+
end
820+
if (n = length(diff_aliases)) >= 2
821+
as = diff_aliases[n-1]
822+
for (c, a) in as
823+
(da = var_to_diff[a]) === nothing && continue
824+
da === r && continue
825+
push!(reach₌, c => da)
826+
end
827+
end
828+
for (c, a) in reach₌
829+
@info fullvars[r] => c * fullvars[a]
830+
end
831+
if r === nothing
832+
# TODO: updated_diff_vars check
833+
isempty(reach₌) && break
834+
dr = first(reach₌)
835+
var_to_diff[prev_r] = dr
836+
push!(updated_diff_vars, prev_r)
837+
prev_r = dr
838+
else
839+
prev_r = r
840+
r = var_to_diff[r]
841+
end
842+
for (c, v) in reach₌
843+
v == prev_r && continue
844+
dag[v] = c => prev_r
845+
end
846+
push!(diff_aliases, reach₌)
847+
end
848+
for v in zero_vars_set
849+
dag[v] = 0
850+
end
851+
@show nlevels
852+
display(diff_aliases)
853+
@assert length(diff_aliases) == nlevels
854+
@show zero_vars_set
855+
856+
# clean up
805857
for v in dls.visited
806858
dls.dists[v] = typemax(Int)
807859
processed[v] = true
808860
end
809861
empty!(dls.visited)
862+
empty!(diff_aliases)
810863
end
864+
@show dag
811865

866+
#=
812867
processed = falses(nvars)
868+
invag = SimpleDiGraph(nvars)
869+
for (v, (coeff, alias)) in pairs(ag)
870+
iszero(coeff) && continue
871+
add_edge!(invag, alias, v)
872+
end
813873
iag = InducedAliasGraph(ag, invag, var_to_diff)
814874
dag = AliasGraph(nvars) # alias graph for differentiated variables
815875
newinvag = SimpleDiGraph(nvars)
@@ -920,6 +980,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
920980
end
921981
end
922982
end
983+
=#
923984

924985
for (v, (c, a)) in dag
925986
a = iszero(a) ? 0 : c * fullvars[a]
@@ -949,6 +1010,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
9491010
push!(removed_aliases, a)
9501011
end
9511012
for (v, (c, a)) in ag
1013+
(processed[v] || processed[a]) && continue
9521014
v in removed_aliases && continue
9531015
freshag[v] = c => a
9541016
end
@@ -959,6 +1021,10 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
9591021
mm = reduce!(copy(echelon_mm), ag)
9601022
@warn "wow" mm
9611023
end
1024+
for (v, (c, a)) in ag
1025+
a = iszero(a) ? 0 : c * fullvars[a]
1026+
@info "ag" fullvars[v] => a
1027+
end
9621028

9631029
# Step 5: Reflect our update decisions back into the graph, and make sure
9641030
# that the RHS of observable variables are defined.

src/systems/sparsematrixclil.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ end
4444
struct CLILVector{T, Ti} <: AbstractSparseVector{T, Ti}
4545
vec::SparseVector{T, Ti}
4646
end
47+
Base.hash(v::CLILVector, s::UInt) = hash(v.vec, s) 0xc71be0e9ccb75fbd
4748
Base.size(v::CLILVector) = Base.size(v.vec)
4849
Base.getindex(v::CLILVector, idx::Integer...) = Base.getindex(v.vec, idx...)
4950
Base.setindex!(vec::CLILVector, v, idx::Integer...) = Base.setindex!(vec.vec, v, idx...)

0 commit comments

Comments
 (0)