Skip to content

Commit b08e81d

Browse files
committed
Mostly works. Needs sign handling in transitive closure to be fully functional
1 parent b464d72 commit b08e81d

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
148148
end
149149

150150
function dummy_derivative_graph!(state::TransformationState, jac = nothing; kwargs...)
151+
Main._state[] = state
151152
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
152153
var_eq_matching = complete(pantelides!(state))
153154
complete!(state.structure)
@@ -191,6 +192,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
191192
iszero(maxlevel) && continue
192193

193194
rank_matching = Matching(nvars)
195+
isfirst = true
194196
for _ in maxlevel:-1:1
195197
eqs = filter(eq -> diff_to_eq[eq] !== nothing, eqs)
196198
nrows = length(eqs)
@@ -205,6 +207,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
205207
# state selection.)
206208
#
207209
# 3. If the Jacobian is a polynomial matrix, use Gröbner basis (?)
210+
if isfirst
211+
vars = sort(vars, by=i->occursin("A", string(Main._state[].fullvars[i])))
212+
end
213+
isfirst = false
208214
if jac !== nothing && (_J = jac(eqs, vars); all(x -> unwrap(x) isa Integer, _J))
209215
J = Int.(unwrap.(_J))
210216
N = ModelingToolkit.nullspace(J; col_order) # modifies col_order
@@ -229,6 +235,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
229235
if rank != nrows
230236
@warn "The DAE system is structurally singular!"
231237
end
238+
@info Main._state[].fullvars[vars]
232239

233240
# prepare the next iteration
234241
eqs = map(eq -> diff_to_eq[eq], eqs)

src/systems/alias_elimination.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,7 @@ function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
298298
add_edge!(g, a, v)
299299
end
300300
transitiveclosure!(g)
301-
zero_vars_set = BitSet(zero_vars)
302-
for v in zero_vars
303-
for a in outneighbors(g, v)
304-
push!(zero_vars_set, a)
305-
end
306-
end
301+
#=
307302
# Compute the largest transitive closure that doesn't include any diff
308303
# edges.
309304
og = g
@@ -320,6 +315,8 @@ function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
320315
end
321316
end
322317
g = newg
318+
=#
319+
eqg = copy(g)
323320

324321
c = "green"
325322
edge_styles = Dict{Tuple{Int, Int}, String}()
@@ -329,7 +326,7 @@ function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
329326
add_edge!(g, v, dv)
330327
add_edge!(g, dv, v)
331328
end
332-
g, zero_vars_set, edge_styles
329+
g, eqg, zero_vars, edge_styles
333330
end
334331

335332
using Graphs.Experimental.Traversals
@@ -794,7 +791,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
794791
updated_diff_vars = Int[]
795792
diff_to_var = invview(var_to_diff)
796793
processed = falses(nvars)
797-
g, zero_vars_set = tograph(ag, var_to_diff)
794+
g, eqg, zero_vars = tograph(ag, var_to_diff)
798795
dls = DiffLevelState(g, var_to_diff)
799796
is_diff_edge = let var_to_diff = var_to_diff
800797
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
@@ -810,6 +807,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
810807
callback = Base.Fix1(push!, level_to_var))
811808
nlevels = length(level_to_var)
812809
prev_r = -1
810+
stem = Int[]
813811
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
814812
reach₌ = Pair{Int, Int}[]
815813
r === nothing || for n in neighbors(g, r)
@@ -841,17 +839,27 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
841839
end
842840
for (c, v) in reach₌
843841
v == prev_r && continue
842+
add_edge!(eqg, v, prev_r)
843+
push!(stem, prev_r)
844844
dag[v] = c => prev_r
845845
end
846846
push!(diff_aliases, reach₌)
847847
end
848-
for v in zero_vars_set
849-
dag[v] = 0
848+
@info "" fullvars[stem]
849+
transitiveclosure!(eqg)
850+
for i in 1:length(stem) - 1
851+
r, dr = stem[i], stem[i+1]
852+
if has_edge(eqg, r, dr)
853+
c = 1
854+
dag[dr] = c => r
855+
end
856+
end
857+
for v in zero_vars, a in outneighbors(g, v)
858+
dag[a] = 0
850859
end
851860
@show nlevels
852861
display(diff_aliases)
853862
@assert length(diff_aliases) == nlevels
854-
@show zero_vars_set
855863

856864
# clean up
857865
for v in dls.visited
@@ -861,6 +869,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
861869
empty!(dls.visited)
862870
empty!(diff_aliases)
863871
end
872+
for k in keys(dag)
873+
dag[k]
874+
end
864875
@show dag
865876

866877
#=
@@ -1010,7 +1021,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
10101021
push!(removed_aliases, a)
10111022
end
10121023
for (v, (c, a)) in ag
1013-
(processed[v] || processed[a]) && continue
1024+
(processed[v] || (!iszero(a) && processed[a])) && continue
10141025
v in removed_aliases && continue
10151026
freshag[v] = c => a
10161027
end

0 commit comments

Comments
 (0)