Skip to content

Commit 4702426

Browse files
Automate tagging of the initialization system to the ODEProblem
1 parent aafddc9 commit 4702426

File tree

4 files changed

+39
-6
lines changed

4 files changed

+39
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ PrecompileTools = "1"
9494
RecursiveArrayTools = "2.3, 3"
9595
Reexport = "0.2, 1"
9696
RuntimeGeneratedFunctions = "0.5.9"
97-
SciMLBase = "2.0.1"
97+
SciMLBase = "2.27"
9898
SciMLStructures = "1.0"
9999
Serialization = "1"
100100
Setfield = "0.7, 0.8, 1"

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,8 @@ function complete(sys::AbstractSystem; split = true)
537537
if split && has_index_cache(sys)
538538
@set! sys.index_cache = IndexCache(sys)
539539
end
540-
if isdefined(sys, :initializationsystem)
541-
@set! sys.initializationsystem = complete(get_initializationsystem(sys); split)
540+
if isdefined(sys, :initializesystem) && get_initializesystem(sys) !== nothing
541+
@set! sys.initializesystem = complete(get_initializesystem(sys); split)
542542
end
543543
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
544544
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
315315
checkbounds = false,
316316
sparsity = false,
317317
analytic = nothing,
318-
split_idxs = nothing,
318+
split_idxs = nothing,
319+
initializeprob = nothing,
320+
initializeprobmap = nothing,
319321
kwargs...) where {iip, specialize}
320322
if !iscomplete(sys)
321323
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -487,6 +489,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
487489
end
488490

489491
@set! sys.split_idxs = split_idxs
492+
490493
ODEFunction{iip, specialize}(f;
491494
sys = sys,
492495
jac = _jac === nothing ? nothing : _jac,
@@ -495,7 +498,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
495498
jac_prototype = jac_prototype,
496499
observed = observedfun,
497500
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
498-
analytic = analytic)
501+
analytic = analytic,
502+
initializeprob = initializeprob,
503+
initializeprobmap = initializeprobmap)
499504
end
500505

501506
"""
@@ -525,6 +530,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
525530
sparse = false, simplify = false,
526531
eval_module = @__MODULE__,
527532
checkbounds = false,
533+
initializeprob = nothing,
534+
initializeprobmap = nothing,
528535
kwargs...) where {iip}
529536
if !iscomplete(sys)
530537
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -596,7 +603,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
596603
sys = sys,
597604
jac = _jac === nothing ? nothing : _jac,
598605
jac_prototype = jac_prototype,
599-
observed = observedfun)
606+
observed = observedfun,
607+
initializeprob = initializeprob,
608+
initializeprobmap = initializeprobmap)
600609
end
601610

602611
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -877,10 +886,15 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
877886

878887
check_eqs_u0(eqs, dvs, u0; kwargs...)
879888

889+
initializeprob = ModelingToolkit.InitializationProblem(sys, u0map, parammap)
890+
initializeprobmap = getu(initializeprob, unknowns(sys))
891+
880892
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
881893
checkbounds = checkbounds, p = p,
882894
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
883895
sparse = sparse, eval_expression = eval_expression,
896+
initializeprob = initializeprob,
897+
initializeprobmap = initializeprobmap,
884898
kwargs...)
885899
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
886900
end

test/initializationsystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@ prob = ODEProblem(sys, allinit, (0, 0.1))
181181
sol = solve(prob, Rodas5P())
182182
# If initialized incorrectly, then it would be InitialFailure
183183
@test sol.retcode == SciMLBase.ReturnCode.Unstable
184+
SciMLBase.has_initializeprob(prob.f)
185+
186+
isys = ModelingToolkit.get_initializesystem(sys)
187+
unknowns(isys)
188+
189+
initprob = ModelingToolkit.InitializationProblem(sys)
190+
sol = solve(initprob)
191+
192+
unknowns(sys)
193+
194+
[sys.act.vol₁.dr]
195+
196+
getter = ModelingToolkit.getu(initprob, unknowns(sys)[end-1:end])
197+
getter(sol)
198+
199+
prob.f.initializeprobmap(initsol)
200+
201+
initsol[unknowns(isys)]
184202

185203
@connector Flange begin
186204
dx(t), [guess = 0]
@@ -250,3 +268,4 @@ prob = ODEProblem(sys, allinit, (0, 0.1))
250268
sol = solve(prob, Rodas5P())
251269
# If initialized incorrectly, then it would be InitialFailure
252270
@test sol.retcode == SciMLBase.ReturnCode.Success
271+

0 commit comments

Comments
 (0)