Skip to content

Commit 12b2efe

Browse files
authored
Merge pull request #1408 from SciML/myb/odaefix
Fix torn_system_jacobian_sparsity
2 parents 7fd455d + 6f16a57 commit 12b2efe

File tree

2 files changed

+42
-22
lines changed

2 files changed

+42
-22
lines changed

src/structural_transformation/codegen.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

7-
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs)
7+
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
88
s = structure(sys)
99
@unpack fullvars, graph = s
1010

@@ -55,39 +55,39 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
5555
dig = DiCMOBiGraph{true}(graph, var_eq_matching)
5656

5757
fused_var_deps = map(1:ndsts(graph)) do v
58-
BitSet(var_rename[v′] for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0)
58+
BitSet(v′ for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0)
5959
end
6060

61-
for scc in var_sccs
61+
for scc in var_sccs[nlsolve_scc_idxs]
6262
if length(scc) >= 2
6363
deps = fused_var_deps[scc[1]]
6464
for c in 2:length(scc)
6565
union!(deps, fused_var_deps[c])
66+
fused_var_deps[c] = deps
6667
end
6768
end
6869
end
6970

70-
nlsolve_eqs = BitSet(var_eq_matching[c]::Int for c in nlsolve_vars if var_eq_matching[c] !== unassigned)
71-
72-
var2idx = Dict(v => i for (i, v) in enumerate(states_idxs))
71+
var2idx = Dict{Int,Int}(v => i for (i, v) in enumerate(states_idxs))
72+
eqs2idx = Dict{Int,Int}(v => i for (i, v) in enumerate(eqs_idxs))
7373
nlsolve_vars_set = BitSet(nlsolve_vars)
7474

7575
I = Int[]; J = Int[]
76-
eqidx = 0
7776
for ieq in 𝑠vertices(graph)
78-
ieq in nlsolve_eqs && continue
79-
eqidx += 1
77+
nieq = get(eqs2idx, ieq, 0)
78+
nieq == 0 && continue
8079
for ivar in 𝑠neighbors(graph, ieq)
8180
isdervar(s, ivar) && continue
8281
if var_rename[ivar] != 0
83-
push!(I, eqidx)
82+
push!(I, nieq)
8483
push!(J, var2idx[ivar])
8584
else
8685
for dvar in fused_var_deps[ivar]
8786
isdervar(s, dvar) && continue
88-
dvar in nlsolve_vars_set && continue
89-
push!(I, eqidx)
90-
push!(J, var2idx[dvar])
87+
niv = get(var2idx, dvar, 0)
88+
niv == 0 && continue
89+
push!(I, nieq)
90+
push!(J, niv)
9191
end
9292
end
9393
end
@@ -176,8 +176,11 @@ function build_torn_function(
176176
max_inlining_size = something(max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
177177
rhss = []
178178
eqs = equations(sys)
179-
for eq in eqs
180-
isdiffeq(eq) && push!(rhss, eq.rhs)
179+
eqs_idxs = Int[]
180+
for (i, eq) in enumerate(eqs)
181+
isdiffeq(eq) || continue
182+
push!(eqs_idxs, i)
183+
push!(rhss, eq.rhs)
181184
end
182185

183186
s = structure(sys)
@@ -198,16 +201,17 @@ function build_torn_function(
198201
for (i, scc) in enumerate(var_sccs)
199202
#torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
200203
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
201-
torn_eqs = [eqs[var_eq_matching[var]] for var in torn_vars_idxs]
202-
isempty(torn_eqs) && continue
203-
if length(torn_eqs) <= max_inlining_size
204-
append!(torn_expr, gen_nlsolve(torn_eqs, s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
204+
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205+
isempty(torn_eqs_idxs) && continue
206+
if length(torn_eqs_idxs) <= max_inlining_size
207+
append!(torn_expr, gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
205208
push!(nlsolve_scc_idxs, i)
206209
else
207210
needs_extending = true
208-
append!(rhss, map(x->x.rhs, torn_eqs))
211+
append!(eqs_idxs, torn_eqs_idxs)
212+
append!(rhss, map(x->x.rhs, eqs[torn_eqs_idxs]))
209213
append!(states_idxs, torn_vars_idxs)
210-
append!(mass_matrix_diag, zeros(length(torn_eqs)))
214+
append!(mass_matrix_diag, zeros(length(torn_eqs_idxs)))
211215
end
212216
end
213217

@@ -253,7 +257,7 @@ function build_torn_function(
253257

254258
ODEFunction{true}(
255259
@RuntimeGeneratedFunction(expr),
256-
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs) : nothing,
260+
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
257261
syms = syms,
258262
observed = observedfun,
259263
mass_matrix = mass_matrix,

test/components.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
using Test
22
using ModelingToolkit, OrdinaryDiffEq
3+
using ModelingToolkit.BipartiteGraphs
4+
5+
function check_contract(sys)
6+
s = structure(sys)
7+
@unpack fullvars, graph = s
8+
eqs = equations(sys)
9+
var2idx = Dict(enumerate(fullvars))
10+
for (i, eq) in enumerate(eqs)
11+
actual = union(ModelingToolkit.vars(eq.lhs), ModelingToolkit.vars(eq.rhs))
12+
actual = filter(!ModelingToolkit.isparameter, collect(actual))
13+
current = Set(fullvars[𝑠neighbors(graph, i)])
14+
@test isempty(setdiff(actual, current))
15+
end
16+
end
317

418
include("../examples/rc_model.jl")
519

620
sys = structural_simplify(rc_model)
21+
check_contract(sys)
722
@test !isempty(ModelingToolkit.defaults(sys))
823
u0 = [
924
capacitor.v => 0.0
@@ -76,6 +91,7 @@ sol = solve(prob, Tsit5())
7691

7792
include("../examples/serial_inductor.jl")
7893
sys = structural_simplify(ll_model)
94+
check_contract(sys)
7995
u0 = [
8096
inductor1.i => 0.0
8197
inductor2.i => 0.0

0 commit comments

Comments
 (0)