Skip to content

Commit 7ee87a5

Browse files
fix: fix array varables split across SCCs in SCCNonlinearProblem
1 parent bb15228 commit 7ee87a5

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -676,23 +676,23 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
676676
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
677677
scc_eqs = Vector{Equation}[]
678678
scc_obs = Vector{Equation}[]
679+
# variables solved in previous SCCs
680+
available_vars = Set()
679681
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
680682
# subset unknowns and equations
681683
_dvs = dvs[vscc]
682684
_eqs = eqs[escc]
683685
# get observed equations required by this SCC
684-
obsidxs = observed_equations_used_by(sys, _eqs)
686+
union!(available_vars, _dvs)
687+
obsidxs = observed_equations_used_by(sys, _eqs; available_vars)
685688
# the ones used by previous SCCs can be precomputed into the cache
686689
setdiff!(obsidxs, prevobsidxs)
687690
_obs = obs[obsidxs]
691+
union!(available_vars, getproperty.(_obs, (:lhs,)))
688692

689693
# get all subexpressions in the RHS which we can precompute in the cache
690694
# precomputed subexpressions should not contain `banned_vars`
691695
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
692-
filter!(banned_vars) do var
693-
symbolic_type(var) != ArraySymbolic() ||
694-
all(j -> var[j] in banned_vars, eachindex(var))
695-
end
696696
state = Dict()
697697
for i in eachindex(_obs)
698698
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
@@ -753,9 +753,12 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
753753
_obs = scc_obs[i]
754754
cachevars = scc_cachevars[i]
755755
cacheexprs = scc_cacheexprs[i]
756+
available_vars = [dvs[reduce(vcat, var_sccs[1:(i - 1)]; init = Int[])];
757+
getproperty.(
758+
reduce(vcat, scc_obs[1:(i - 1)]; init = []), (:lhs,))]
756759
_prevobsidxs = vcat(_prevobsidxs,
757-
observed_equations_used_by(sys, reduce(vcat, values(cacheexprs); init = [])))
758-
760+
observed_equations_used_by(
761+
sys, reduce(vcat, values(cacheexprs); init = []); available_vars))
759762
if isempty(cachevars)
760763
push!(explicitfuns, Returns(nothing))
761764
else

test/scc_nonlinear_problem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,13 @@ import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC
253253
sol = solve(prob)
254254
@test SciMLBase.successful_retcode(sol)
255255
end
256+
257+
@testset "Array variables split across SCCs" begin
258+
@variables x[1:3]
259+
@parameters (f::Function)(..)
260+
@mtkbuild sys = NonlinearSystem([
261+
0 ~ x[1]^2 - 9, x[2] ~ 2x[1], 0 ~ x[3]^2 - x[1]^2 + f(x)])
262+
prob = SCCNonlinearProblem(sys, [x => ones(3)], [f => sum])
263+
sol = solve(prob, NewtonRaphson())
264+
@test SciMLBase.successful_retcode(sol)
265+
end

0 commit comments

Comments
 (0)