Skip to content

Commit 6d809e4

Browse files
authored
Merge pull request #2193 from SciML/myb/opt_ddg
Optimize dummy derivative graph
2 parents 85df648 + d70442d commit 6d809e4

File tree

1 file changed

+62
-31
lines changed

1 file changed

+62
-31
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -175,55 +175,55 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing;
175175
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority)
176176
end
177177

178-
function compute_diff_level(diff_to_x)
179-
nxs = length(diff_to_x)
180-
xlevel = zeros(Int, nxs)
181-
maxlevel = 0
182-
for i in 1:nxs
183-
level = 0
184-
x = i
185-
while diff_to_x[x] !== nothing
186-
x = diff_to_x[x]
187-
level += 1
188-
end
189-
maxlevel = max(maxlevel, level)
190-
xlevel[i] = level
191-
end
192-
return xlevel, maxlevel
193-
end
194-
195178
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac,
196179
state_priority)
197180
@unpack eq_to_diff, var_to_diff, graph = structure
198181
diff_to_eq = invview(eq_to_diff)
199182
diff_to_var = invview(var_to_diff)
200183
invgraph = invview(graph)
201184

202-
eqlevel, _ = compute_diff_level(diff_to_eq)
203-
204185
var_sccs = find_var_sccs(graph, var_eq_matching)
205186
eqcolor = falses(nsrcs(graph))
206187
dummy_derivatives = Int[]
207188
col_order = Int[]
208189
nvars = ndsts(graph)
190+
eqs = Int[]
191+
next_eq_idxs = Int[]
192+
next_var_idxs = Int[]
193+
new_eqs = Int[]
194+
new_vars = Int[]
195+
eqs_set = BitSet()
209196
for vars in var_sccs
210-
eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned]
197+
empty!(eqs)
198+
for var in vars
199+
eq = var_eq_matching[var]
200+
eq isa Int || continue
201+
diff_to_eq[eq] === nothing && continue
202+
push!(eqs, eq)
203+
end
211204
isempty(eqs) && continue
212-
maxlevel = maximum(Base.Fix1(getindex, eqlevel), eqs)
213-
iszero(maxlevel) && continue
214205

215206
rank_matching = Matching(nvars)
216207
isfirst = true
217-
for _ in maxlevel:-1:1
218-
eqs = filter(eq -> diff_to_eq[eq] !== nothing, eqs)
208+
if jac === nothing
209+
J = nothing
210+
else
211+
_J = jac(eqs, vars)
212+
# only accecpt small intergers to avoid overflow
213+
is_all_small_int = all(_J) do x′
214+
x = unwrap(x′)
215+
x isa Number || return false
216+
isinteger(x) && typemin(Int8) <= x <= typemax(Int8)
217+
end
218+
J = is_all_small_int ? Int.(unwrap.(_J)) : nothing
219+
end
220+
while true
219221
nrows = length(eqs)
220222
iszero(nrows) && break
221-
eqs_set = BitSet(eqs)
222223

223224
if state_priority !== nothing && isfirst
224225
sort!(vars, by = state_priority)
225226
end
226-
isfirst = false
227227
# TODO: making the algorithm more robust
228228
# 1. If the Jacobian is a integer matrix, use Bareiss to check
229229
# linear independence. (done)
@@ -232,14 +232,18 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
232232
# state selection.)
233233
#
234234
# 3. If the Jacobian is a polynomial matrix, use Gröbner basis (?)
235-
if jac !== nothing && (_J = jac(eqs, vars); all(x -> unwrap(x) isa Integer, _J))
236-
J = Int.(unwrap.(_J))
235+
if J !== nothing
236+
if !isfirst
237+
J = J[next_eq_idxs, next_var_idxs]
238+
end
237239
N = ModelingToolkit.nullspace(J; col_order) # modifies col_order
238240
rank = length(col_order) - size(N, 2)
239241
for i in 1:rank
240242
push!(dummy_derivatives, vars[col_order[i]])
241243
end
242244
else
245+
empty!(eqs_set)
246+
union!(eqs_set, eqs)
243247
rank = 0
244248
for var in vars
245249
eqcolor .= false
@@ -255,12 +259,39 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
255259
fill!(rank_matching, unassigned)
256260
end
257261
if rank != nrows
258-
@warn "The DAE system is structurally singular!"
262+
@warn "The DAE system is singular!"
259263
end
260264

261265
# prepare the next iteration
262-
eqs = map(eq -> diff_to_eq[eq], eqs)
263-
vars = [diff_to_var[var] for var in vars if diff_to_var[var] !== nothing]
266+
if J !== nothing
267+
empty!(next_eq_idxs)
268+
empty!(next_var_idxs)
269+
end
270+
empty!(new_eqs)
271+
empty!(new_vars)
272+
for (i, eq) in enumerate(eqs)
273+
∫eq = diff_to_eq[eq]
274+
# descend by one diff level, but the next iteration of equations
275+
# must still be differentiated
276+
∫eq === nothing && continue
277+
∫∫eq = diff_to_eq[∫eq]
278+
∫∫eq === nothing && continue
279+
if J !== nothing
280+
push!(next_eq_idxs, i)
281+
end
282+
push!(new_eqs, ∫eq)
283+
end
284+
for (i, var) in enumerate(vars)
285+
∫var = diff_to_var[var]
286+
∫var === nothing && continue
287+
if J !== nothing
288+
push!(next_var_idxs, i)
289+
end
290+
push!(new_vars, ∫var)
291+
end
292+
eqs, new_eqs = new_eqs, eqs
293+
vars, new_vars = new_vars, vars
294+
isfirst = false
264295
end
265296
end
266297

0 commit comments

Comments
 (0)