Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -676,23 +676,23 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
scc_eqs = Vector{Equation}[]
scc_obs = Vector{Equation}[]
# variables solved in previous SCCs
available_vars = Set()
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
# subset unknowns and equations
_dvs = dvs[vscc]
_eqs = eqs[escc]
# get observed equations required by this SCC
obsidxs = observed_equations_used_by(sys, _eqs)
union!(available_vars, _dvs)
obsidxs = observed_equations_used_by(sys, _eqs; available_vars)
# the ones used by previous SCCs can be precomputed into the cache
setdiff!(obsidxs, prevobsidxs)
_obs = obs[obsidxs]
union!(available_vars, getproperty.(_obs, (:lhs,)))

# get all subexpressions in the RHS which we can precompute in the cache
# precomputed subexpressions should not contain `banned_vars`
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
filter!(banned_vars) do var
symbolic_type(var) != ArraySymbolic() ||
all(j -> var[j] in banned_vars, eachindex(var))
end
state = Dict()
for i in eachindex(_obs)
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
Expand Down Expand Up @@ -753,9 +753,12 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
_obs = scc_obs[i]
cachevars = scc_cachevars[i]
cacheexprs = scc_cacheexprs[i]
available_vars = [dvs[reduce(vcat, var_sccs[1:(i - 1)]; init = Int[])];
getproperty.(
reduce(vcat, scc_obs[1:(i - 1)]; init = []), (:lhs,))]
_prevobsidxs = vcat(_prevobsidxs,
observed_equations_used_by(sys, reduce(vcat, values(cacheexprs); init = [])))

observed_equations_used_by(
sys, reduce(vcat, values(cacheexprs); init = []); available_vars))
if isempty(cachevars)
push!(explicitfuns, Returns(nothing))
else
Expand Down
12 changes: 11 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1068,14 +1068,24 @@ Keyword arguments:
providing this keyword is not necessary and is only useful to avoid repeatedly calling
`vars(exprs)`
- `obs`: the list of observed equations.
- `available_vars`: If `exprs` involves a variable `x[1]`, this function will look for
observed equations whose LHS is `x[1]` OR `x`. Sometimes, the latter is not required
since `x[1]` might already be present elsewhere in the generated code (e.g. an argument
to the function) but other elements of `x` are part of the observed equations, thus
requiring them to be obtained from the equation for `x`. Any variable present in
`available_vars` will not be searched for in the observed equations.
"""
function observed_equations_used_by(sys::AbstractSystem, exprs;
involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys))
involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys), available_vars = [])
obsvars = getproperty.(obs, :lhs)
graph = observed_dependency_graph(obs)
if !(available_vars isa Set)
available_vars = Set(available_vars)
end

obsidxs = BitSet()
for sym in involved_vars
sym in available_vars && continue
arrsym = iscall(sym) && operation(sym) === getindex ? arguments(sym)[1] : nothing
idx = findfirst(v -> isequal(v, sym) || isequal(v, arrsym), obsvars)
idx === nothing && continue
Expand Down
10 changes: 10 additions & 0 deletions test/scc_nonlinear_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@
β = 1e-6
R0 = 1000
R = 9000
Ue(t) = 0.1 * sin(200 * π * t)

Check warning on line 96 in test/scc_nonlinear_problem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Ue" should be "Use" or "Due".

function transamp(out, du, u, p, t)
g(x) = 1e-6 * (exp(x / 0.026) - 1)
y1, y2, y3, y4, y5, y6, y7, y8 = u
out[1] = -Ue(t) / R0 + y1 / R0 + C[1] * du[1] - C[1] * du[2]

Check warning on line 101 in test/scc_nonlinear_problem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Ue" should be "Use" or "Due".
out[2] = -Ub / R + y2 * 2 / R - (α - 1) * g(y2 - y3) - C[1] * du[1] + C[1] * du[2]
out[3] = -g(y2 - y3) + y3 / R + C[2] * du[3]
out[4] = -Ub / R + y4 / R + α * g(y2 - y3) + C[3] * du[4] - C[3] * du[5]
Expand Down Expand Up @@ -253,3 +253,13 @@
sol = solve(prob)
@test SciMLBase.successful_retcode(sol)
end

@testset "Array variables split across SCCs" begin
@variables x[1:3]
@parameters (f::Function)(..)
@mtkbuild sys = NonlinearSystem([
0 ~ x[1]^2 - 9, x[2] ~ 2x[1], 0 ~ x[3]^2 - x[1]^2 + f(x)])
prob = SCCNonlinearProblem(sys, [x => ones(3)], [f => sum])
sol = solve(prob, NewtonRaphson())
@test SciMLBase.successful_retcode(sol)
end
Loading