Skip to content

Commit 2698015

Browse files
fix: handle remake with no pre-existing initializeprob
1 parent 2641fd8 commit 2698015

File tree

1 file changed

+47
-11
lines changed

1 file changed

+47
-11
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -210,17 +210,16 @@ function is_parameter_solvable(p, pmap, defs, guesses)
210210
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
211211
end
212212

213-
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
213+
function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp)
214214
if u0 === missing && p === missing
215-
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
216-
odefn.initializeprobpmap
215+
return odefn.initialization_data
217216
end
218217
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
219218
oldinitprob = odefn.initializeprob
220-
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
221-
!(oldinitprob.f.sys isa NonlinearSystem)
222-
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
223-
odefn.initializeprobpmap
219+
oldinitprob === nothing && return nothing
220+
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
221+
return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!,
222+
odefn.initializeprobmap, odefn.initializeprobpmap)
224223
end
225224
pidxs = ParameterIndex[]
226225
pvals = []
@@ -262,14 +261,17 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
262261
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
263262
end
264263
initprob = remake(oldinitprob; u0 = newu0, p = newp)
265-
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
266-
odefn.initializeprobpmap
264+
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!,
265+
odefn.initializeprobmap, odefn.initializeprobpmap)
267266
end
268267
dvs = unknowns(sys)
269268
ps = parameters(sys)
270269
u0map = to_varmap(u0, dvs)
270+
symbols_to_symbolics!(sys, u0map)
271271
pmap = to_varmap(p, ps)
272+
symbols_to_symbolics!(sys, pmap)
272273
guesses = Dict()
274+
defs = defaults(sys)
273275
if SciMLBase.has_initializeprob(odefn)
274276
oldsys = odefn.initializeprob.f.sys
275277
meta = get_metadata(oldsys)
@@ -278,6 +280,35 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
278280
pmap = merge(meta.pmap, pmap)
279281
merge!(guesses, meta.additional_guesses)
280282
end
283+
else
284+
# there is no initializeprob, so the original problem construction
285+
# had no solvable parameters and had the differential variables
286+
# specified in `u0map`.
287+
if u0 === missing
288+
# the user didn't pass `u0` to `remake`, so they want to retain
289+
# existing values. Fill the differential variables in `u0map`,
290+
# initialization will either be elided or solve for the algebraic
291+
# variables
292+
diff_idxs = isdiffeq.(equations(sys))
293+
for i in eachindex(dvs)
294+
diff_idxs[i] || continue
295+
u0map[dvs[i]] = newu0[i]
296+
end
297+
end
298+
if p === missing
299+
# the user didn't pass `p` to `remake`, so they want to retain
300+
# existing values. Fill all parameters in `pmap` so that none of
301+
# them are solvable.
302+
for p in ps
303+
pmap[p] = getp(sys, p)(newp)
304+
end
305+
end
306+
# all non-solvable parameters need values regardless
307+
for p in ps
308+
haskey(pmap, p) && continue
309+
is_parameter_solvable(p, pmap, defs, guesses) && continue
310+
pmap[p] = getp(sys, p)(newp)
311+
end
281312
end
282313
if t0 === nothing
283314
t0 = 0.0
@@ -286,8 +317,13 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
286317
filter_missing_values!(pmap)
287318
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
288319
kws = f.kwargs
289-
return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing),
290-
get(kws, :initializeprobpmap, nothing)
320+
initprob = get(kws, :initializeprob, nothing)
321+
if initprob === nothing
322+
return nothing
323+
end
324+
return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing),
325+
get(kws, :initializeprobmap, nothing),
326+
get(kws, :initializeprobpmap, nothing))
291327
end
292328

293329
"""

0 commit comments

Comments
 (0)