Skip to content

Commit a2cb472

Browse files
fixup! feat: allow initialization of null integrators
1 parent bcaffaa commit a2cb472

File tree

1 file changed

+3
-26
lines changed

1 file changed

+3
-26
lines changed

src/initialize_dae.jl

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,7 @@ end
140140
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
141141
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
142142
initializeprob = prob.f.initializeprob
143-
if initializeprob.f.sys !== nothing && prob.f.sys !== nothing
144-
if initializeprob.u0 === nothing || isempty(initializeprob.u0)
145-
initu0 = Float64[]
146-
else
147-
initu0vars = variable_symbols(initializeprob)
148-
initu0order = variable_index.((initializeprob,), initu0vars)
149-
# Variable symbols are not guaranteed to be in order
150-
invpermute!(initu0vars, initu0order)
151-
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
152-
end
153-
initp = remake_buffer(initializeprob, parameter_values(initializeprob),
154-
Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob)))
155-
initializeprob = remake(initializeprob; u0 = initu0, p = initp)
156-
end
143+
prob.f.initializeprob_init!(initializeprob, integrator)
157144
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
158145
# Since then it's the case of not a DAE but has initializeprob
159146
# In which case, it should be differentiable
@@ -168,19 +155,9 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
168155
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
169156
nlsol = solve(initializeprob, alg)
170157
if isinplace === Val{true}()
171-
if prob.u0 !== nothing && !isempty(prob.u0)
172-
integrator.u .= prob.f.initializeprobmap(nlsol)
173-
end
174-
if SciMLBase.has_initializeprob_updatep(prob.f)
175-
prob.f.initializeprob_updatep!(integrator.p, nlsol)
176-
end
158+
prob.f.initializeprob_update!(integrator, nlsol)
177159
elseif isinplace === Val{false}()
178-
if prob.u0 !== nothing && !isempty(prob.u0)
179-
integrator.u .= prob.f.initializeprobmap(nlsol)
180-
end
181-
if SciMLBase.has_initializeprob_updatep(prob.f)
182-
prob.f.initializeprob_updatep!(integrator.p, nlsol)
183-
end
160+
prob.f.initializeprob_update!(integrator, nlsol)
184161
else
185162
error("Unreachable reached. Report this error.")
186163
end

0 commit comments

Comments
 (0)