Skip to content

Commit 186302f

Browse files
late binding initialization_eqs
1 parent a905587 commit 186302f

File tree

5 files changed

+35
-6
lines changed

5 files changed

+35
-6
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ for prop in [:eqs
576576
:preface
577577
:torn_matching
578578
:initializesystem
579+
:initialization_eqs
579580
:schedule
580581
:tearing_state
581582
:substitutions

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
864864

865865
# TODO: Pass already computed information to varmap_to_vars call
866866
# in process_u0? That would just be a small optimization
867-
varmap = isempty(u0map) ? defaults : merge(defaults, todict(u0map))
867+
varmap = isempty(u0map) ? defaults(sys) : merge(defaults(sys), todict(u0map))
868868
varlist = collect(map(unwrap, dvs))
869869
missingvars = setdiff(varlist, collect(keys(varmap)))
870870

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ struct ODESystem <: AbstractODESystem
101101
"""
102102
initializesystem::Union{Nothing, NonlinearSystem}
103103
"""
104+
Extra equations to be enforced during the initialization sequence.
105+
"""
106+
initialization_eqs::Vector{Equation}
107+
"""
104108
The schedule for the code generation process.
105109
"""
106110
schedule::Any
@@ -171,7 +175,8 @@ struct ODESystem <: AbstractODESystem
171175

172176
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
173177
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses,
174-
torn_matching, initializesystem, schedule, connector_type, preface, cevents,
178+
torn_matching, initializesystem, initialization_eqs, schedule,
179+
connector_type, preface, cevents,
175180
devents, parameter_dependencies,
176181
metadata = nothing, gui_metadata = nothing,
177182
tearing_state = nothing,
@@ -190,8 +195,8 @@ struct ODESystem <: AbstractODESystem
190195
end
191196
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
192197
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, torn_matching,
193-
initializesystem, schedule, connector_type, preface, cevents, devents, parameter_dependencies,
194-
metadata,
198+
initializesystem, initialization_eqs, schedule, connector_type, preface,
199+
cevents, devents, parameter_dependencies, metadata,
195200
gui_metadata, tearing_state, substitutions, complete, index_cache,
196201
discrete_subsystems, solved_unknowns, split_idxs, parent)
197202
end
@@ -208,6 +213,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
208213
defaults = _merge(Dict(default_u0), Dict(default_p)),
209214
guesses = Dict(),
210215
initializesystem = nothing,
216+
initialization_eqs = Equation[],
211217
schedule = nothing,
212218
connector_type = nothing,
213219
preface = nothing,
@@ -260,7 +266,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
260266
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
261267
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
262268
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, nothing, initializesystem,
263-
schedule, connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
269+
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
270+
disc_callbacks, parameter_dependencies,
264271
metadata, gui_metadata, checks = checks)
265272
end
266273

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function generate_initializesystem(sys::ODESystem;
5555
end
5656

5757
pars = [parameters(sys); get_iv(sys)]
58-
nleqs = [eqs_ics; observed(sys)]
58+
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]
5959

6060
sys_nl = NonlinearSystem(nleqs,
6161
full_states,

test/initializationsystem.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,24 @@ prob = ODEProblem(sys, [sys.ddx => -2], (0, 1), guesses = [sys.dx => 1])
355355
sol = solve(prob, Tsit5())
356356
@test SciMLBase.successful_retcode(sol)
357357
@test sol[1] == [1.0]
358+
359+
## Late binding initialization_eqs
360+
361+
function System(; name)
362+
vars = @variables begin
363+
dx(t), [guess = 0]
364+
ddx(t), [guess = 0]
365+
end
366+
eqs = [D(dx) ~ ddx
367+
0 ~ ddx + dx + 1]
368+
initialization_eqs = [
369+
ddx ~ -2
370+
]
371+
return ODESystem(eqs, t, vars, []; name, initialization_eqs)
372+
end
373+
374+
@mtkbuild sys = System()
375+
prob = ODEProblem(sys, [], (0, 1), guesses = [sys.dx => 1])
376+
sol = solve(prob, Tsit5())
377+
@test SciMLBase.successful_retcode(sol)
378+
@test sol[1] == [1.0]

0 commit comments

Comments
 (0)