Skip to content

Commit 5f74f30

Browse files
feat: better handle observed variables, constants in SCCNonlinearProblem
1 parent ff79993 commit 5f74f30

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -539,17 +539,22 @@ function (cw::CacheWriter)(p, sols)
539539
cw.fn(p.caches[1], sols, p...)
540540
end
541541

542-
function CacheWriter(sys::AbstractSystem, exprs, solsyms;
542+
function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation};
543543
eval_expression = false, eval_module = @__MODULE__)
544544
ps = parameters(sys)
545545
rps = reorder_parameters(sys, ps)
546+
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
547+
cmap, cs = get_cmap(sys)
548+
cmap_assigns = [eq.lhs eq.rhs for eq in cmap]
546549
fn = Func(
547550
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
548551
DestructuredArgs.(rps)...],
549552
[],
550553
SetArray(true, :out, exprs)
551-
) |> wrap_parameter_dependencies(sys, false)[2] |>
552-
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |> toexpr
554+
) |> wrap_assignments(false, obs_assigns)[2] |>
555+
wrap_parameter_dependencies(sys, false)[2] |>
556+
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |>
557+
wrap_assignments(false, cmap_assigns)[2] |> toexpr
553558
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
554559
end
555560

@@ -608,7 +613,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
608613
var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
609614

610615
if length(var_sccs) == 1
611-
return NonlinearProblem{iip}(sys, u0map, parammap; eval_expression, eval_module, kwargs...)
616+
return NonlinearProblem{iip}(
617+
sys, u0map, parammap; eval_expression, eval_module, kwargs...)
612618
end
613619

614620
condensed_graph = MatchedCondensationGraph(
@@ -660,7 +666,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
660666

661667
# cached variables and their corresponding expressions
662668
cachevars = Any[obs[i].lhs for i in prevobsidxs]
663-
cacheexprs = Any[obs[i].rhs for i in prevobsidxs]
669+
cacheexprs = Any[obs[i].lhs for i in prevobsidxs]
664670
for (k, v) in state
665671
push!(cachevars, unwrap(v))
666672
push!(cacheexprs, unwrap(k))
@@ -672,7 +678,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
672678
else
673679
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
674680
push!(explicitfuns,
675-
CacheWriter(sys, cacheexprs, solsyms; eval_expression, eval_module))
681+
CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs];
682+
eval_expression, eval_module))
676683
end
677684
f = SCCNonlinearFunction{iip}(
678685
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)

0 commit comments

Comments
 (0)