Skip to content

Commit c1f2e70

Browse files
committed
fix ODAEProblem Jacobian sparsity
1 parent ada195f commit c1f2e70

File tree

1 file changed

+48
-46
lines changed

1 file changed

+48
-46
lines changed

src/structural_transformation/codegen.jl

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

7-
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
7+
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs)
88
s = structure(sys)
99
@unpack fullvars, graph = s
1010

@@ -42,54 +42,52 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
4242
# from previous partitions. Hence, we can build the dependency chain as we
4343
# traverse the partitions.
4444

45-
# `avars2dvars` maps a algebraic variable to its differential variable
46-
# dependencies.
47-
avars2dvars = Dict{Int,Set{Int}}()
48-
c = 0
49-
for scc in var_sccs
50-
v_residual = scc
51-
e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] !== unassigned]
52-
# initialization
53-
for tvar in v_residual
54-
avars2dvars[tvar] = Set{Int}()
45+
var_rename = ones(Int64, ndsts(graph))
46+
nlsolve_vars = Int[]
47+
for i in nlsolve_scc_idxs, c in var_sccs[i]
48+
append!(nlsolve_vars, c)
49+
for v in c
50+
var_rename[v] = 0
5551
end
56-
for teq in e_residual
57-
c += 1
58-
for var in 𝑠neighbors(graph, teq)
59-
# Skip the tearing variables in the current partition, because
60-
# we are computing them from all the other states.
61-
Graphs.insorted(var, v_residual) && continue
62-
deps = get(avars2dvars, var, nothing)
63-
if deps === nothing # differential variable
64-
@assert !isalgvar(s, var)
65-
for tvar in v_residual
66-
push!(avars2dvars[tvar], var)
67-
end
68-
else # tearing variable from previous partitions
69-
@assert isalgvar(s, var)
70-
for tvar in v_residual
71-
union!(avars2dvars[tvar], avars2dvars[var])
72-
end
73-
end
52+
end
53+
masked_cumsum!(var_rename)
54+
55+
dig = DiCMOBiGraph{true}(graph, var_eq_matching)
56+
57+
fused_var_deps = map(1:ndsts(graph)) do v
58+
BitSet(var_rename[v′] for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0)
59+
end
60+
61+
for scc in var_sccs
62+
if length(scc) >= 2
63+
deps = fused_var_deps[scc[1]]
64+
for c in 2:length(scc)
65+
union!(deps, fused_var_deps[c])
7466
end
7567
end
7668
end
7769

78-
dvrange = diffvars_range(s)
79-
dvar2idx = Dict(v=>i for (i, v) in enumerate(dvrange))
70+
nlsolve_eqs = BitSet(var_eq_matching[c]::Int for c in nlsolve_vars if var_eq_matching[c] !== unassigned)
71+
72+
var2idx = Dict(v => i for (i, v) in enumerate(states_idxs))
73+
nlsolve_vars_set = BitSet(nlsolve_vars)
74+
8075
I = Int[]; J = Int[]
8176
eqidx = 0
8277
for ieq in 𝑠vertices(graph)
83-
isalgeq(s, ieq) && continue
78+
ieq in nlsolve_eqs && continue
8479
eqidx += 1
8580
for ivar in 𝑠neighbors(graph, ieq)
86-
if isdiffvar(s, ivar)
81+
isdervar(s, ivar) && continue
82+
if var_rename[ivar] != 0
8783
push!(I, eqidx)
88-
push!(J, dvar2idx[ivar])
89-
elseif isalgvar(s, ivar)
90-
for dvar in avars2dvars[ivar]
84+
push!(J, var2idx[ivar])
85+
else
86+
for dvar in fused_var_deps[ivar]
87+
isdervar(s, dvar) && continue
88+
dvar in nlsolve_vars_set && continue
9189
push!(I, eqidx)
92-
push!(J, dvar2idx[dvar])
90+
push!(J, var2idx[dvar])
9391
end
9492
end
9593
end
@@ -180,22 +178,25 @@ function build_torn_function(
180178
toporder = topological_sort_by_dfs(condensed_graph)
181179
var_sccs = var_sccs[toporder]
182180

183-
states = map(i->s.fullvars[i], diffvars_range(s))
184-
mass_matrix_diag = ones(length(states))
181+
states_idxs = collect(diffvars_range(s))
182+
mass_matrix_diag = ones(length(states_idxs))
185183
torn_expr = []
186184
defs = defaults(sys)
185+
nlsolve_scc_idxs = Int[]
187186

188187
needs_extending = false
189-
for scc in var_sccs
190-
torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
191-
torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] !== unassigned]
188+
for (i, scc) in enumerate(var_sccs)
189+
#torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
190+
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
191+
torn_eqs = [eqs[var_eq_matching[var]] for var in torn_vars_idxs]
192192
isempty(torn_eqs) && continue
193193
if length(torn_eqs) <= max_inlining_size
194-
append!(torn_expr, gen_nlsolve(torn_eqs, torn_vars, defs, checkbounds=checkbounds))
194+
append!(torn_expr, gen_nlsolve(torn_eqs, s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
195+
push!(nlsolve_scc_idxs, i)
195196
else
196197
needs_extending = true
197198
append!(rhss, map(x->x.rhs, torn_eqs))
198-
append!(states, torn_vars)
199+
append!(states_idxs, torn_vars_idxs)
199200
append!(mass_matrix_diag, zeros(length(torn_eqs)))
200201
end
201202
end
@@ -209,7 +210,8 @@ function build_torn_function(
209210
rhss
210211
)
211212

212-
syms = map(Symbol, states)
213+
states = s.fullvars[states_idxs]
214+
syms = map(Symbol, states_idxs)
213215
pre = get_postprocess_fbody(sys)
214216

215217
expr = SymbolicUtils.Code.toexpr(
@@ -241,7 +243,7 @@ function build_torn_function(
241243

242244
ODEFunction{true}(
243245
@RuntimeGeneratedFunction(expr),
244-
sparsity = torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs),
246+
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs) : nothing,
245247
syms = syms,
246248
observed = observedfun,
247249
mass_matrix = mass_matrix,

0 commit comments

Comments
 (0)