Skip to content

Commit f0720b7

Browse files
feat: allow initialization of null integrators
1 parent e3b7d7a commit f0720b7

File tree

1 file changed

+34
-9
lines changed

1 file changed

+34
-9
lines changed

src/initialize_dae.jl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ function DiffEqBase.initialize_dae!(integrator::ODEIntegrator,
7171
Val(DiffEqBase.isinplace(integrator.sol.prob)))
7272
end
7373

74+
function DiffEqBase.initialize_dae!(integrator::DiffEqBase.NullODEIntegrator)
75+
_initialize_dae!(integrator, integrator.sol.prob,
76+
OverrideInit(),
77+
Val(DiffEqBase.isinplace(integrator.sol.prob)))
78+
end
79+
7480
## Default algorithms
7581

7682
function _initialize_dae!(integrator, prob::ODEProblem,
@@ -135,27 +141,46 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
135141
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
136142
initializeprob = prob.f.initializeprob
137143
if initializeprob.f.sys !== nothing && prob.f.sys !== nothing
138-
initu0vars = variable_symbols(initializeprob)
139-
initu0order = variable_index.((initializeprob,), initu0vars)
140-
# Variable symbols are not guaranteed to be in order
141-
invpermute!(initu0vars, initu0order)
142-
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
144+
if initializeprob.u0 === nothing || isempty(initializeprob.u0)
145+
initu0 = Float64[]
146+
else
147+
initu0vars = variable_symbols(initializeprob)
148+
initu0order = variable_index.((initializeprob,), initu0vars)
149+
# Variable symbols are not guaranteed to be in order
150+
invpermute!(initu0vars, initu0order)
151+
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
152+
end
143153
initp = remake_buffer(initializeprob, parameter_values(initializeprob),
144154
Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob)))
145155
initializeprob = remake(initializeprob; u0 = initu0, p = initp)
146156
end
147157
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
148158
# Since then it's the case of not a DAE but has initializeprob
149159
# In which case, it should be differentiable
150-
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff :
151-
true
160+
isAD = if !isdefined(integrator, :alg)
161+
false
162+
elseif has_autodiff(integrator.alg)
163+
alg_autodiff(integrator.alg) isa AutoForwardDiff
164+
else
165+
true
166+
end
152167

153168
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
154169
nlsol = solve(initializeprob, alg)
155170
if isinplace === Val{true}()
156-
integrator.u .= prob.f.initializeprobmap(nlsol)
171+
if prob.u0 !== nothing && !isempty(prob.u0)
172+
integrator.u .= prob.f.initializeprobmap(nlsol)
173+
end
174+
if SciMLBase.has_initializeprob_updatep(prob.f)
175+
prob.f.initializeprob_updatep!(integrator.p, nlsol)
176+
end
157177
elseif isinplace === Val{false}()
158-
integrator.u = prob.f.initializeprobmap(nlsol)
178+
if prob.u0 !== nothing && !isempty(prob.u0)
179+
integrator.u .= prob.f.initializeprobmap(nlsol)
180+
end
181+
if SciMLBase.has_initializeprob_updatep(prob.f)
182+
prob.f.initializeprob_updatep!(integrator.p, nlsol)
183+
end
159184
else
160185
error("Unreachable reached. Report this error.")
161186
end

0 commit comments

Comments
 (0)