Skip to content

Commit cb0d444

Browse files
committed
Add solver_states
1 parent 3d391ee commit cb0d444

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

src/structural_transformation/codegen.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ function build_torn_function(sys;
297297
out,
298298
rhss)
299299

300-
states = fullvars[states_idxs]
300+
states = Any[fullvars[i] for i in states_idxs]
301+
@set! sys.solver_states = states
301302
syms = map(Symbol, states)
302303

303304
pre = get_postprocess_fbody(sys)

src/systems/abstractsystem.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ for prop in [:eqs
213213
:tearing_state
214214
:substitutions
215215
:metadata
216-
:discrete_subsystems]
216+
:discrete_subsystems
217+
:solver_states]
217218
fname1 = Symbol(:get_, prop)
218219
fname2 = Symbol(:has_, prop)
219220
@eval begin
@@ -466,7 +467,7 @@ function namespace_expr(O, sys, n = nameof(sys))
466467
end
467468
end
468469

469-
function SymbolicIndexingInterface.states(sys::AbstractSystem)
470+
function states(sys::AbstractSystem)
470471
sts = get_states(sys)
471472
systems = get_systems(sys)
472473
unique(isempty(systems) ?
@@ -580,8 +581,16 @@ end
580581

581582
SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))
582583

584+
function solver_states(sys::AbstractSystem)
585+
sts = states(sys)
586+
if has_solver_states(sys)
587+
sts = something(get_solver_states(sys), sts)
588+
end
589+
return sts
590+
end
591+
583592
function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
584-
findfirst(isequal(sym), SymbolicIndexingInterface.states(sys))
593+
findfirst(isequal(sym), solver_states(sys))
585594
end
586595
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
587596
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,21 @@ struct ODESystem <: AbstractODESystem
127127
"""
128128
complete::Bool
129129
"""
130-
discrete_subsystems: a list of discrete subsystems
130+
discrete_subsystems: a list of discrete subsystems.
131131
"""
132132
discrete_subsystems::Any
133+
"""
134+
solver_states: a list of actual solver states. Only used for ODAEProblem.
135+
"""
136+
solver_states::Union{Nothing, Vector{Any}}
133137

134138
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
135139
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
136140
torn_matching, connector_type, preface, cevents,
137141
devents, metadata = nothing, tearing_state = nothing,
138142
substitutions = nothing, complete = false,
139-
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
143+
discrete_subsystems = nothing, solver_states = nothing;
144+
checks::Union{Bool, Int} = true)
140145
if checks == true || (checks & CheckComponents) > 0
141146
check_variables(dvs, iv)
142147
check_parameters(ps, iv)
@@ -149,7 +154,7 @@ struct ODESystem <: AbstractODESystem
149154
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
150155
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
151156
connector_type, preface, cevents, devents, metadata, tearing_state,
152-
substitutions, complete, discrete_subsystems)
157+
substitutions, complete, discrete_subsystems, solver_states)
153158
end
154159
end
155160

0 commit comments

Comments
 (0)