Skip to content

Commit 600e7ad

Browse files
fix: fix array varables split across SCCs in SCCNonlinearProblem
1 parent 5c80e09 commit 600e7ad

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -669,23 +669,23 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
669669
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
670670
scc_eqs = Vector{Equation}[]
671671
scc_obs = Vector{Equation}[]
672+
# variables solved in previous SCCs
673+
available_vars = Set()
672674
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
673675
# subset unknowns and equations
674676
_dvs = dvs[vscc]
675677
_eqs = eqs[escc]
676678
# get observed equations required by this SCC
677-
obsidxs = observed_equations_used_by(sys, _eqs)
679+
union!(available_vars, _dvs)
680+
obsidxs = observed_equations_used_by(sys, _eqs; available_vars)
678681
# the ones used by previous SCCs can be precomputed into the cache
679682
setdiff!(obsidxs, prevobsidxs)
680683
_obs = obs[obsidxs]
684+
union!(available_vars, getproperty.(_obs, (:lhs,)))
681685

682686
# get all subexpressions in the RHS which we can precompute in the cache
683687
# precomputed subexpressions should not contain `banned_vars`
684688
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
685-
filter!(banned_vars) do var
686-
symbolic_type(var) != ArraySymbolic() ||
687-
all(j -> var[j] in banned_vars, eachindex(var))
688-
end
689689
state = Dict()
690690
for i in eachindex(_obs)
691691
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
@@ -746,9 +746,12 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
746746
_obs = scc_obs[i]
747747
cachevars = scc_cachevars[i]
748748
cacheexprs = scc_cacheexprs[i]
749+
available_vars = [dvs[reduce(vcat, var_sccs[1:(i - 1)]; init = Int[])];
750+
getproperty.(
751+
reduce(vcat, scc_obs[1:(i - 1)]; init = []), (:lhs,))]
749752
_prevobsidxs = vcat(_prevobsidxs,
750-
observed_equations_used_by(sys, reduce(vcat, values(cacheexprs); init = [])))
751-
753+
observed_equations_used_by(
754+
sys, reduce(vcat, values(cacheexprs); init = []); available_vars))
752755
if isempty(cachevars)
753756
push!(explicitfuns, Returns(nothing))
754757
else

test/scc_nonlinear_problem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,13 @@ import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC
253253
sol = solve(prob)
254254
@test SciMLBase.successful_retcode(sol)
255255
end
256+
257+
@testset "Array variables split across SCCs" begin
258+
@variables x[1:3]
259+
@parameters (f::Function)(..)
260+
@mtkbuild sys = NonlinearSystem([
261+
0 ~ x[1]^2 - 9, x[2] ~ 2x[1], 0 ~ x[3]^2 - x[1]^2 + f(x)])
262+
prob = SCCNonlinearProblem(sys, [x => ones(3)], [f => sum])
263+
sol = solve(prob, NewtonRaphson())
264+
@test SciMLBase.successful_retcode(sol)
265+
end

0 commit comments

Comments
 (0)