Skip to content

Commit 58f4885

Browse files
fix: fix array varables split across SCCs in SCCNonlinearProblem
1 parent c10a7d2 commit 58f4885

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -666,23 +666,23 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
666666
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
667667
scc_eqs = Vector{Equation}[]
668668
scc_obs = Vector{Equation}[]
669+
# variables solved in previous SCCs
670+
available_vars = Set()
669671
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
670672
# subset unknowns and equations
671673
_dvs = dvs[vscc]
672674
_eqs = eqs[escc]
673675
# get observed equations required by this SCC
674-
obsidxs = observed_equations_used_by(sys, _eqs)
676+
union!(available_vars, _dvs)
677+
obsidxs = observed_equations_used_by(sys, _eqs; available_vars)
675678
# the ones used by previous SCCs can be precomputed into the cache
676679
setdiff!(obsidxs, prevobsidxs)
677680
_obs = obs[obsidxs]
681+
union!(available_vars, getproperty.(_obs, (:lhs,)))
678682

679683
# get all subexpressions in the RHS which we can precompute in the cache
680684
# precomputed subexpressions should not contain `banned_vars`
681685
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
682-
filter!(banned_vars) do var
683-
symbolic_type(var) != ArraySymbolic() ||
684-
all(j -> var[j] in banned_vars, eachindex(var))
685-
end
686686
state = Dict()
687687
for i in eachindex(_obs)
688688
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
@@ -743,9 +743,12 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
743743
_obs = scc_obs[i]
744744
cachevars = scc_cachevars[i]
745745
cacheexprs = scc_cacheexprs[i]
746+
available_vars = [dvs[reduce(vcat, var_sccs[1:(i - 1)]; init = Int[])];
747+
getproperty.(
748+
reduce(vcat, scc_obs[1:(i - 1)]; init = []), (:lhs,))]
746749
_prevobsidxs = vcat(_prevobsidxs,
747-
observed_equations_used_by(sys, reduce(vcat, values(cacheexprs); init = [])))
748-
750+
observed_equations_used_by(
751+
sys, reduce(vcat, values(cacheexprs); init = []); available_vars))
749752
if isempty(cachevars)
750753
push!(explicitfuns, Returns(nothing))
751754
else

test/scc_nonlinear_problem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,13 @@ import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC
253253
sol = solve(prob)
254254
@test SciMLBase.successful_retcode(sol)
255255
end
256+
257+
@testset "Array variables split across SCCs" begin
258+
@variables x[1:3]
259+
@parameters (f::Function)(..)
260+
@mtkbuild sys = NonlinearSystem([
261+
0 ~ x[1]^2 - 9, x[2] ~ 2x[1], 0 ~ x[3]^2 - x[1]^2 + f(x)])
262+
prob = SCCNonlinearProblem(sys, [x => ones(3)], [f => sum])
263+
sol = solve(prob, NewtonRaphson())
264+
@test SciMLBase.successful_retcode(sol)
265+
end

0 commit comments

Comments
 (0)