Skip to content

Commit c8fb745

Browse files
committed
WIP: update TearingState in alias_elimination
1 parent 31f611b commit c8fb745

File tree

3 files changed

+79
-17
lines changed

3 files changed

+79
-17
lines changed

src/structural_transformation/utils.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ end
4343
###
4444
### Structural check
4545
###
46-
function check_consistency(state::TearingState)
46+
function check_consistency(state::TearingState, ag = nothing)
4747
fullvars = state.fullvars
4848
@unpack graph, var_to_diff = state.structure
49+
#n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 && !isempty(𝑑neighbors(graph, v)),
50+
# vertices(var_to_diff))
51+
#n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 && !haskey(ag, v),
52+
# vertices(var_to_diff))
4953
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0,
5054
vertices(var_to_diff))
5155
neqs = nsrcs(graph)
@@ -69,11 +73,11 @@ function check_consistency(state::TearingState)
6973
# details, check the equation (15) of the original paper.
7074
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
7175
map(collect, edges(var_to_diff))])
72-
extended_var_eq_matching = maximal_matching(extended_graph)
76+
extended_var_eq_matching = maximal_matching(extended_graph, eq->true, v->!haskey(ag, v))
7377

7478
unassigned_var = []
7579
for (vj, eq) in enumerate(extended_var_eq_matching)
76-
if eq === unassigned
80+
if eq === unassigned && !haskey(ag, vj)
7781
push!(unassigned_var, fullvars[vj])
7882
end
7983
end

src/systems/abstractsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,12 +1035,12 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
10351035
has_io = io !== nothing
10361036
has_io && markio!(state, io...)
10371037
state, input_idxs = inputs_to_parameters!(state, io)
1038-
sys = alias_elimination!(state)
1038+
sys, ag = alias_elimination!(state)
10391039
# TODO: avoid construct `TearingState` again.
1040-
state = TearingState(sys)
1041-
has_io && markio!(state, io..., check = false)
1042-
check_consistency(state)
1043-
find_solvables!(state; kwargs...)
1040+
#state = TearingState(sys)
1041+
#has_io && markio!(state, io..., check = false)
1042+
check_consistency(state, ag)
1043+
#find_solvables!(state; kwargs...)
10441044
sys = dummy_derivative(sys, state; simplify)
10451045
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
10461046
@set! sys.observed = topsort_equations(observed(sys), fullstates)

src/systems/alias_elimination.jl

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)
4747
level === nothing ? v : (v => level)
4848
end
4949

50-
alias_elimination(sys) = alias_elimination!(TearingState(sys))
50+
alias_elimination(sys) = alias_elimination!(TearingState(sys))[1]
5151
function alias_elimination!(state::TearingState)
5252
sys = state.sys
5353
complete!(state.structure)
@@ -56,7 +56,7 @@ function alias_elimination!(state::TearingState)
5656
isempty(ag) && return sys
5757

5858
fullvars = state.fullvars
59-
@unpack var_to_diff, graph = state.structure
59+
@unpack var_to_diff, graph, solvable_graph = state.structure
6060

6161
if !isempty(updated_diff_vars)
6262
has_iv(sys) ||
@@ -105,19 +105,36 @@ function alias_elimination!(state::TearingState)
105105
end
106106
end
107107
deleteat!(eqs, sort!(dels))
108-
old_to_new = Vector{Int}(undef, nsrcs(graph))
108+
old_to_new_eq = Vector{Int}(undef, nsrcs(graph))
109109
idx = 0
110110
cursor = 1
111111
ndels = length(dels)
112-
for i in eachindex(old_to_new)
112+
for i in eachindex(old_to_new_eq)
113113
if cursor <= ndels && i == dels[cursor]
114114
cursor += 1
115-
old_to_new[i] = -1
115+
old_to_new_eq[i] = -1
116116
continue
117117
end
118118
idx += 1
119-
old_to_new[i] = idx
119+
old_to_new_eq[i] = idx
120120
end
121+
n_new_eqs = idx
122+
123+
old_to_new_var = Vector{Int}(undef, ndsts(graph))
124+
idx = 0
125+
for i in eachindex(old_to_new_var)
126+
if haskey(ag, i)
127+
old_to_new_var[i] = -1
128+
else
129+
idx += 1
130+
old_to_new_var[i] = idx
131+
end
132+
end
133+
n_new_vars = idx
134+
#for d in dels
135+
# set_neighbors!(graph, d, ())
136+
# set_neighbors!(solvable_graph, d, ())
137+
#end
121138

122139
lineqs = BitSet(mm.nzrows)
123140
eqs_to_update = BitSet()
@@ -126,7 +143,7 @@ function alias_elimination!(state::TearingState)
126143
while true
127144
for ieq in 𝑑neighbors(graph_orig, k)
128145
ieq in lineqs && continue
129-
new_eq = old_to_new[ieq]
146+
new_eq = old_to_new_eq[ieq]
130147
new_eq < 1 && continue
131148
push!(eqs_to_update, new_eq)
132149
end
@@ -139,7 +156,7 @@ function alias_elimination!(state::TearingState)
139156
end
140157

141158
for old_ieq in to_expand
142-
ieq = old_to_new[old_ieq]
159+
ieq = old_to_new_eq[old_ieq]
143160
eqs[ieq] = expand_derivatives(eqs[ieq])
144161
end
145162

@@ -150,12 +167,53 @@ function alias_elimination!(state::TearingState)
150167
diff_to_var[j] === nothing && push!(newstates, fullvars[j])
151168
end
152169
end
170+
#=
171+
new_graph = BipartiteGraph(n_new_eqs, ndsts(graph))
172+
new_solvable_graph = BipartiteGraph(n_new_eqs, ndsts(graph))
173+
new_eq_to_diff = DiffGraph(n_new_eqs)
174+
eq_to_diff = state.structure.eq_to_diff
175+
for (i, ieq) in enumerate(old_to_new_eq)
176+
ieq > 0 || continue
177+
set_neighbors!(new_graph, ieq, 𝑠neighbors(graph, i))
178+
set_neighbors!(new_solvable_graph, ieq, 𝑠neighbors(solvable_graph, i))
179+
new_eq_to_diff[ieq] = eq_to_diff[i]
180+
end
181+
state.structure.graph = new_graph
182+
state.structure.solvable_graph = new_solvable_graph
183+
state.structure.eq_to_diff = new_eq_to_diff
184+
@show length(new_eq_to_diff), nsrcs(new_graph), nsrcs(new_solvable_graph), length(eqs)
185+
=#
186+
187+
new_graph = BipartiteGraph(n_new_eqs, n_new_vars)
188+
new_solvable_graph = BipartiteGraph(n_new_eqs, n_new_vars)
189+
new_eq_to_diff = DiffGraph(n_new_eqs)
190+
eq_to_diff = state.structure.eq_to_diff
191+
new_var_to_diff = DiffGraph(n_new_vars)
192+
var_to_diff = state.structure.var_to_diff
193+
for (i, ieq) in enumerate(old_to_new_eq)
194+
ieq > 0 || continue
195+
set_neighbors!(new_graph, ieq, [old_to_new_var[v] for v in 𝑠neighbors(graph, i) if old_to_new_var[v] > 0])
196+
set_neighbors!(new_solvable_graph, ieq, [old_to_new_var[v] for v in 𝑠neighbors(solvable_graph, i) if old_to_new_var[v] > 0])
197+
new_eq_to_diff[ieq] = eq_to_diff[i]
198+
end
199+
new_fullvars = Vector{Any}(undef, n_new_vars)
200+
for (i, iv) in enumerate(old_to_new_var)
201+
iv > 0 || continue
202+
new_var_to_diff[iv] = var_to_diff[i]
203+
new_fullvars[iv] = fullvars[i]
204+
end
205+
state.structure.graph = new_graph
206+
state.structure.solvable_graph = new_solvable_graph
207+
state.structure.eq_to_diff = complete(new_eq_to_diff)
208+
state.structure.var_to_diff = complete(new_var_to_diff)
209+
state.fullvars = new_fullvars
153210

154211
sys = state.sys
155212
@set! sys.eqs = eqs
156213
@set! sys.states = newstates
157214
@set! sys.observed = [observed(sys); obs]
158-
return invalidate_cache!(sys)
215+
state.sys = sys
216+
return invalidate_cache!(sys), ag
159217
end
160218

161219
"""

0 commit comments

Comments
 (0)