@@ -47,7 +47,7 @@ function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)
47
47
level === nothing ? v : (v => level)
48
48
end
49
49
50
- alias_elimination (sys) = alias_elimination! (TearingState (sys))
50
+ alias_elimination (sys) = alias_elimination! (TearingState (sys))[ 1 ]
51
51
function alias_elimination! (state:: TearingState )
52
52
sys = state. sys
53
53
complete! (state. structure)
@@ -56,7 +56,7 @@ function alias_elimination!(state::TearingState)
56
56
isempty (ag) && return sys
57
57
58
58
fullvars = state. fullvars
59
- @unpack var_to_diff, graph = state. structure
59
+ @unpack var_to_diff, graph, solvable_graph = state. structure
60
60
61
61
if ! isempty (updated_diff_vars)
62
62
has_iv (sys) ||
@@ -105,19 +105,36 @@ function alias_elimination!(state::TearingState)
105
105
end
106
106
end
107
107
deleteat! (eqs, sort! (dels))
108
- old_to_new = Vector {Int} (undef, nsrcs (graph))
108
+ old_to_new_eq = Vector {Int} (undef, nsrcs (graph))
109
109
idx = 0
110
110
cursor = 1
111
111
ndels = length (dels)
112
- for i in eachindex (old_to_new )
112
+ for i in eachindex (old_to_new_eq )
113
113
if cursor <= ndels && i == dels[cursor]
114
114
cursor += 1
115
- old_to_new [i] = - 1
115
+ old_to_new_eq [i] = - 1
116
116
continue
117
117
end
118
118
idx += 1
119
- old_to_new [i] = idx
119
+ old_to_new_eq [i] = idx
120
120
end
121
+ n_new_eqs = idx
122
+
123
+ old_to_new_var = Vector {Int} (undef, ndsts (graph))
124
+ idx = 0
125
+ for i in eachindex (old_to_new_var)
126
+ if haskey (ag, i)
127
+ old_to_new_var[i] = - 1
128
+ else
129
+ idx += 1
130
+ old_to_new_var[i] = idx
131
+ end
132
+ end
133
+ n_new_vars = idx
134
+ # for d in dels
135
+ # set_neighbors!(graph, d, ())
136
+ # set_neighbors!(solvable_graph, d, ())
137
+ # end
121
138
122
139
lineqs = BitSet (mm. nzrows)
123
140
eqs_to_update = BitSet ()
@@ -126,7 +143,7 @@ function alias_elimination!(state::TearingState)
126
143
while true
127
144
for ieq in 𝑑neighbors (graph_orig, k)
128
145
ieq in lineqs && continue
129
- new_eq = old_to_new [ieq]
146
+ new_eq = old_to_new_eq [ieq]
130
147
new_eq < 1 && continue
131
148
push! (eqs_to_update, new_eq)
132
149
end
@@ -139,7 +156,7 @@ function alias_elimination!(state::TearingState)
139
156
end
140
157
141
158
for old_ieq in to_expand
142
- ieq = old_to_new [old_ieq]
159
+ ieq = old_to_new_eq [old_ieq]
143
160
eqs[ieq] = expand_derivatives (eqs[ieq])
144
161
end
145
162
@@ -150,12 +167,53 @@ function alias_elimination!(state::TearingState)
150
167
diff_to_var[j] === nothing && push! (newstates, fullvars[j])
151
168
end
152
169
end
170
+ #=
171
+ new_graph = BipartiteGraph(n_new_eqs, ndsts(graph))
172
+ new_solvable_graph = BipartiteGraph(n_new_eqs, ndsts(graph))
173
+ new_eq_to_diff = DiffGraph(n_new_eqs)
174
+ eq_to_diff = state.structure.eq_to_diff
175
+ for (i, ieq) in enumerate(old_to_new_eq)
176
+ ieq > 0 || continue
177
+ set_neighbors!(new_graph, ieq, 𝑠neighbors(graph, i))
178
+ set_neighbors!(new_solvable_graph, ieq, 𝑠neighbors(solvable_graph, i))
179
+ new_eq_to_diff[ieq] = eq_to_diff[i]
180
+ end
181
+ state.structure.graph = new_graph
182
+ state.structure.solvable_graph = new_solvable_graph
183
+ state.structure.eq_to_diff = new_eq_to_diff
184
+ @show length(new_eq_to_diff), nsrcs(new_graph), nsrcs(new_solvable_graph), length(eqs)
185
+ =#
186
+
187
+ new_graph = BipartiteGraph (n_new_eqs, n_new_vars)
188
+ new_solvable_graph = BipartiteGraph (n_new_eqs, n_new_vars)
189
+ new_eq_to_diff = DiffGraph (n_new_eqs)
190
+ eq_to_diff = state. structure. eq_to_diff
191
+ new_var_to_diff = DiffGraph (n_new_vars)
192
+ var_to_diff = state. structure. var_to_diff
193
+ for (i, ieq) in enumerate (old_to_new_eq)
194
+ ieq > 0 || continue
195
+ set_neighbors! (new_graph, ieq, [old_to_new_var[v] for v in 𝑠neighbors (graph, i) if old_to_new_var[v] > 0 ])
196
+ set_neighbors! (new_solvable_graph, ieq, [old_to_new_var[v] for v in 𝑠neighbors (solvable_graph, i) if old_to_new_var[v] > 0 ])
197
+ new_eq_to_diff[ieq] = eq_to_diff[i]
198
+ end
199
+ new_fullvars = Vector {Any} (undef, n_new_vars)
200
+ for (i, iv) in enumerate (old_to_new_var)
201
+ iv > 0 || continue
202
+ new_var_to_diff[iv] = var_to_diff[i]
203
+ new_fullvars[iv] = fullvars[i]
204
+ end
205
+ state. structure. graph = new_graph
206
+ state. structure. solvable_graph = new_solvable_graph
207
+ state. structure. eq_to_diff = complete (new_eq_to_diff)
208
+ state. structure. var_to_diff = complete (new_var_to_diff)
209
+ state. fullvars = new_fullvars
153
210
154
211
sys = state. sys
155
212
@set! sys. eqs = eqs
156
213
@set! sys. states = newstates
157
214
@set! sys. observed = [observed (sys); obs]
158
- return invalidate_cache! (sys)
215
+ state. sys = sys
216
+ return invalidate_cache! (sys), ag
159
217
end
160
218
161
219
"""
0 commit comments