Skip to content

Commit 43ffe51

Browse files
committed
Add weighted transitive closure
1 parent f959c57 commit 43ffe51

File tree

3 files changed

+90
-52
lines changed

3 files changed

+90
-52
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
3636
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3737
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
3838
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
39+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
3940
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4041
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4142
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -72,6 +73,7 @@ Reexport = "0.2, 1"
7273
RuntimeGeneratedFunctions = "0.4.3, 0.5"
7374
SciMLBase = "1.58.0"
7475
Setfield = "0.7, 0.8, 1"
76+
SimpleWeightedGraphs = "1"
7577
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7678
StaticArrays = "0.10, 0.11, 0.12, 1.0"
7779
SymbolicUtils = "0.19"
@@ -93,9 +95,9 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
9395
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
9496
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9597
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
98+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9699
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
97100
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
98-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
99101
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
100102
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
101103
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

src/systems/alias_elimination.jl

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using SymbolicUtils: Rewriters
2+
using SimpleWeightedGraphs
3+
using Graphs.Experimental.Traversals
24

35
const KEEP = typemin(Int)
46

@@ -288,35 +290,22 @@ end
288290

289291
function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
290292
g = SimpleDiGraph{Int}(length(var_to_diff))
293+
eqg = SimpleWeightedGraph{Int, Int}(length(var_to_diff))
291294
zero_vars = Int[]
292-
for (v, (_, a)) in ag
295+
for (v, (c, a)) in ag
293296
if iszero(a)
294297
push!(zero_vars, v)
295298
continue
296299
end
297300
add_edge!(g, v, a)
298301
add_edge!(g, a, v)
302+
303+
add_edge!(eqg, v, a, c)
304+
add_edge!(eqg, a, v, c)
299305
end
300306
transitiveclosure!(g)
301-
#=
302-
# Compute the largest transitive closure that doesn't include any diff
303-
# edges.
304-
og = g
305-
newg = SimpleDiGraph{Int}(length(var_to_diff))
306-
for e in Graphs.edges(og)
307-
s, d = src(e), dst(e)
308-
(var_to_diff[s] == d || var_to_diff[d] == s) && continue
309-
oldg = copy(newg)
310-
add_edge!(newg, s, d)
311-
add_edge!(newg, d, s)
312-
transitiveclosure!(newg)
313-
if any(e->(var_to_diff[src(e)] == dst(e) || var_to_diff[dst(e)] == src(e)), edges(newg))
314-
newg = oldg
315-
end
316-
end
317-
g = newg
318-
=#
319-
eqg = copy(g)
307+
Main._a[] = copy(eqg)
308+
weighted_transitiveclosure!(eqg)
320309

321310
c = "green"
322311
edge_styles = Dict{Tuple{Int, Int}, String}()
@@ -329,7 +318,17 @@ function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
329318
g, eqg, zero_vars, edge_styles
330319
end
331320

332-
using Graphs.Experimental.Traversals
321+
function weighted_transitiveclosure!(g)
322+
cps = connected_components(g)
323+
for cp in cps
324+
for k in cp, i in cp, j in cp
325+
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
326+
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
327+
end
328+
end
329+
return g
330+
end
331+
333332
struct DiffLevelState <: Traversals.AbstractTraversalState
334333
dists::Vector{Int}
335334
var_to_diff::DiffGraph
@@ -763,6 +762,10 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
763762
ag = AliasGraph(nvars)
764763
mm, echelon_mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
765764
fullvars = Main._state[].fullvars
765+
for (v, (c, a)) in ag
766+
a = iszero(a) ? 0 : c * fullvars[a]
767+
@info "ag" fullvars[v] => a
768+
end
766769

767770
# Step 3: Handle differentiated variables
768771
# At this point, `var_to_diff` and `ag` form a tree structure like the
@@ -802,17 +805,14 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
802805
(dv === nothing && diff_to_var[v] === nothing) && continue
803806
r = find_root!(dls, g, v)
804807
@show fullvars[r]
805-
level_to_var = Int[]
806-
extreme_var(var_to_diff, r, nothing, Val(false),
807-
callback = Base.Fix1(push!, level_to_var))
808-
nlevels = length(level_to_var)
809808
prev_r = -1
810809
stem = Int[]
810+
stem_set = BitSet()
811811
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
812812
reach₌ = Pair{Int, Int}[]
813-
r === nothing || for n in neighbors(g, r)
813+
r === nothing || for n in neighbors(eqg, r)
814814
(n == r || is_diff_edge(r, n)) && continue
815-
c = 1
815+
c = get_weight(eqg, r, n)
816816
push!(reach₌, c => n)
817817
end
818818
if (n = length(diff_aliases)) >= 1
@@ -823,45 +823,80 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
823823
push!(reach₌, c => da)
824824
end
825825
end
826-
for (c, a) in reach₌
827-
@info fullvars[r] => c * fullvars[a]
828-
end
829826
if r === nothing
830-
@warn "hi"
831-
# TODO: updated_diff_vars check
832827
isempty(reach₌) && break
833-
dr = first(reach₌)
828+
idx = findfirst(x->x[1] == 1, reach₌)
829+
if idx === nothing
830+
c, dr = reach₌[1]
831+
@assert c == -1
832+
dag[dr] = (c, dr)
833+
else
834+
c, dr = reach₌[idx]
835+
@assert c == 1
836+
end
834837
var_to_diff[prev_r] = dr
835838
push!(updated_diff_vars, prev_r)
836839
prev_r = dr
837840
else
838-
@warn "" fullvars[r]
839841
prev_r = r
840842
r = var_to_diff[r]
841843
end
842-
for (c, v) in reach₌
844+
for (c, a) in reach₌
845+
if r === nothing
846+
var_to_diff[prev_r] === nothing && continue
847+
@info fullvars[var_to_diff[prev_r]] => c * fullvars[a]
848+
else
849+
@info fullvars[r] => c * fullvars[a]
850+
end
851+
end
852+
prev_r in stem_set && break
853+
push!(stem_set, prev_r)
854+
push!(stem, prev_r)
855+
push!(diff_aliases, reach₌)
856+
for (_, v) in reach₌
843857
v == prev_r && continue
844858
add_edge!(eqg, v, prev_r)
845-
push!(stem, prev_r)
846-
dag[v] = c => prev_r
847859
end
848-
push!(diff_aliases, reach₌)
849860
end
861+
862+
@show fullvars[updated_diff_vars]
863+
@info "" fullvars
864+
@show stem
865+
@show diff_aliases
850866
@info "" fullvars[stem]
851-
transitiveclosure!(eqg)
867+
display(diff_aliases)
868+
@assert length(stem) == length(diff_aliases)
869+
for i in eachindex(stem)
870+
a = stem[i]
871+
for (c, v) in diff_aliases[i]
872+
# alias edges that coincide with diff edges are handled later
873+
v in stem_set && continue
874+
dag[v] = c => a
875+
end
876+
end
877+
# Obtain transitive closure after completing the alias edges from diff
878+
# edges.
879+
weighted_transitiveclosure!(eqg)
880+
# Canonicalize by preferring the lower differentiated variable
852881
for i in 1:length(stem) - 1
853-
r, dr = stem[i], stem[i+1]
854-
if has_edge(eqg, r, dr)
855-
c = 1
856-
dag[dr] = c => r
882+
r = stem[i]
883+
for dr in @view stem[i+1:end]
884+
if has_edge(eqg, r, dr)
885+
c = get_weight(eqg, r, dr)
886+
dag[dr] = c => r
887+
end
857888
end
858889
end
859-
for v in zero_vars, a in outneighbors(g, v)
860-
dag[a] = 0
890+
for v in zero_vars
891+
for a in Iterators.flatten((v, outneighbors(eqg, v)))
892+
while true
893+
dag[a] = 0
894+
da = var_to_diff[a]
895+
da === nothing && break
896+
a = da
897+
end
898+
end
861899
end
862-
@show nlevels
863-
display(diff_aliases)
864-
@assert length(diff_aliases) == nlevels
865900

866901
# clean up
867902
for v in dls.visited
@@ -871,6 +906,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
871906
empty!(dls.visited)
872907
empty!(diff_aliases)
873908
end
909+
@show dag
874910
for k in keys(dag)
875911
dag[k]
876912
end

test/reduction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,14 @@ D = Differential(t)
263263
@named sys = ODESystem([D(x) ~ 1 - x,
264264
D(y) + D(x) ~ 0])
265265
new_sys = alias_elimination(sys)
266-
@test equations(new_sys) == [D(x) ~ 1 - x; 0 ~ -D(x) - D(y)]
266+
@test equations(new_sys) == [D(x) ~ 1 - x; D(x) + D(y) ~ 0]
267267
@test isempty(observed(new_sys))
268268

269269
@named sys = ODESystem([D(x) ~ x,
270270
D(y) + D(x) ~ 0])
271271
new_sys = alias_elimination(sys)
272-
@test equations(new_sys) == [0 ~ D(D(y)) - D(y)]
273-
@test observed(new_sys) == [x ~ -D(y)]
272+
@test equations(new_sys) == equations(sys)
273+
@test isempty(observed(new_sys))
274274

275275
@named sys = ODESystem([D(x) ~ 1 - x,
276276
y + D(x) ~ 0])

0 commit comments

Comments
 (0)