Skip to content

Commit 69c728b

Browse files
committed
Optimize dummy derivative graph
1 parent 20b9203 commit 69c728b

File tree

1 file changed

+57
-7
lines changed

1 file changed

+57
-7
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,38 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
188188
dummy_derivatives = Int[]
189189
col_order = Int[]
190190
nvars = ndsts(graph)
191+
eqs = Int[]
192+
next_eq_idxs = Int[]
193+
next_var_idxs = Int[]
194+
new_eqs = Int[]
195+
new_vars = Int[]
191196
for vars in var_sccs
192-
eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned]
197+
empty!(eqs)
198+
for var in vars
199+
eq = var_eq_matching[var]
200+
eq isa Int || continue
201+
diff_to_eq[eq] === nothing && continue
202+
push!(eqs, eq)
203+
end
193204
isempty(eqs) && continue
194205
maxlevel = maximum(Base.Fix1(getindex, eqlevel), eqs)
195206
iszero(maxlevel) && continue
196207

197208
rank_matching = Matching(nvars)
198209
isfirst = true
199-
for _ in maxlevel:-1:1
200-
eqs = filter(eq -> diff_to_eq[eq] !== nothing, eqs)
210+
if jac === nothing
211+
J = nothing
212+
else
213+
_J = jac(eqs, vars)
214+
# only accecpt small intergers to avoid overflow
215+
is_all_small_int = all(_J) do x′
216+
x = unwrap(x′)
217+
x isa Number || return false
218+
isinteger(x) && typemin(Int8) <= x <= typemax(Int8)
219+
end
220+
J = is_all_small_int ? Int.(unwrap.(_J)) : nothing
221+
end
222+
for level in maxlevel:-1:1
201223
nrows = length(eqs)
202224
iszero(nrows) && break
203225
eqs_set = BitSet(eqs)
@@ -214,8 +236,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
214236
# state selection.)
215237
#
216238
# 3. If the Jacobian is a polynomial matrix, use Gröbner basis (?)
217-
if jac !== nothing && (_J = jac(eqs, vars); all(x -> unwrap(x) isa Integer, _J))
218-
J = Int.(unwrap.(_J))
239+
if J !== nothing
240+
if level < maxlevel
241+
J = J[next_eq_idxs, next_var_idxs]
242+
end
219243
N = ModelingToolkit.nullspace(J; col_order) # modifies col_order
220244
rank = length(col_order) - size(N, 2)
221245
for i in 1:rank
@@ -241,8 +265,34 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
241265
end
242266

243267
# prepare the next iteration
244-
eqs = map(eq -> diff_to_eq[eq], eqs)
245-
vars = [diff_to_var[var] for var in vars if diff_to_var[var] !== nothing]
268+
if J !== nothing
269+
empty!(next_eq_idxs)
270+
empty!(next_var_idxs)
271+
end
272+
empty!(new_eqs)
273+
empty!(new_vars)
274+
for (i, eq) in enumerate(eqs)
275+
∫eq = diff_to_eq[eq]
276+
# descend by one diff level, but the next iteration of equations
277+
# must still be differentiated
278+
∫eq === nothing && continue
279+
∫∫eq = diff_to_eq[∫eq]
280+
∫∫eq === nothing && continue
281+
if J !== nothing
282+
push!(next_eq_idxs, i)
283+
end
284+
push!(new_eqs, ∫eq)
285+
end
286+
for (i, var) in enumerate(vars)
287+
∫var = diff_to_var[var]
288+
∫var === nothing && continue
289+
if J !== nothing
290+
push!(next_var_idxs, i)
291+
end
292+
push!(new_vars, ∫var)
293+
end
294+
eqs, new_eqs = new_eqs, eqs
295+
vars, new_vars = new_vars, vars
246296
end
247297
end
248298

0 commit comments

Comments
 (0)