Skip to content

Commit 3afb1ff

Browse files
authored
Merge pull request #2016 from SciML/myb/solver_states
Add `unknown_states`
2 parents c84300c + 8669cd3 commit 3afb1ff

File tree

3 files changed

+30
-9
lines changed

3 files changed

+30
-9
lines changed

src/structural_transformation/codegen.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ function build_torn_function(sys;
297297
out,
298298
rhss)
299299

300-
states = fullvars[states_idxs]
300+
states = Any[fullvars[i] for i in states_idxs]
301+
@set! sys.unknown_states = states
301302
syms = map(Symbol, states)
302303

303304
pre = get_postprocess_fbody(sys)
@@ -402,7 +403,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
402403

403404
fullvars = state.fullvars
404405
s = state.structure
405-
solver_states = fullvars[is_solver_state_idxs]
406+
unknown_states = fullvars[is_solver_state_idxs]
406407
algvars = fullvars[.!is_solver_state_idxs]
407408

408409
required_algvars = Set(intersect(algvars, vars))
@@ -489,7 +490,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
489490
cpre = get_preprocess_constants([obs[1:maxidx];
490491
isscalar ? ts[1] : MakeArray(ts, output_type)])
491492
pre2 = x -> pre(cpre(x))
492-
ex = Code.toexpr(Func([DestructuredArgs(solver_states, inbounds = !checkbounds)
493+
ex = Code.toexpr(Func([DestructuredArgs(unknown_states, inbounds = !checkbounds)
493494
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
494495
independent_variables(sys)],
495496
[],

src/systems/abstractsystem.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ for prop in [:eqs
213213
:tearing_state
214214
:substitutions
215215
:metadata
216-
:discrete_subsystems]
216+
:discrete_subsystems
217+
:unknown_states]
217218
fname1 = Symbol(:get_, prop)
218219
fname2 = Symbol(:has_, prop)
219220
@eval begin
@@ -466,7 +467,7 @@ function namespace_expr(O, sys, n = nameof(sys))
466467
end
467468
end
468469

469-
function SymbolicIndexingInterface.states(sys::AbstractSystem)
470+
function states(sys::AbstractSystem)
470471
sts = get_states(sys)
471472
systems = get_systems(sys)
472473
unique(isempty(systems) ?
@@ -580,8 +581,21 @@ end
580581

581582
SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))
582583

584+
"""
585+
$(SIGNATURES)
586+
587+
Return a list of actual states needed to be solved by solvers.
588+
"""
589+
function unknown_states(sys::AbstractSystem)
590+
sts = states(sys)
591+
if has_unknown_states(sys)
592+
sts = something(get_unknown_states(sys), sts)
593+
end
594+
return sts
595+
end
596+
583597
function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
584-
findfirst(isequal(sym), SymbolicIndexingInterface.states(sys))
598+
findfirst(isequal(sym), unknown_states(sys))
585599
end
586600
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
587601
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,22 @@ struct ODESystem <: AbstractODESystem
127127
"""
128128
complete::Bool
129129
"""
130-
discrete_subsystems: a list of discrete subsystems
130+
discrete_subsystems: a list of discrete subsystems.
131131
"""
132132
discrete_subsystems::Any
133+
"""
134+
unknown_states: a list of actual states needed to be solved by solvers. Only
135+
used for ODAEProblem.
136+
"""
137+
unknown_states::Union{Nothing, Vector{Any}}
133138

134139
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
135140
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
136141
torn_matching, connector_type, preface, cevents,
137142
devents, metadata = nothing, tearing_state = nothing,
138143
substitutions = nothing, complete = false,
139-
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
144+
discrete_subsystems = nothing, unknown_states = nothing;
145+
checks::Union{Bool, Int} = true)
140146
if checks == true || (checks & CheckComponents) > 0
141147
check_variables(dvs, iv)
142148
check_parameters(ps, iv)
@@ -149,7 +155,7 @@ struct ODESystem <: AbstractODESystem
149155
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
150156
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
151157
connector_type, preface, cevents, devents, metadata, tearing_state,
152-
substitutions, complete, discrete_subsystems)
158+
substitutions, complete, discrete_subsystems, unknown_states)
153159
end
154160
end
155161

0 commit comments

Comments
 (0)