Skip to content

Commit 2c8d016

Browse files
fix: handle remake with no pre-existing initializeprob
1 parent d487fc8 commit 2c8d016

File tree

1 file changed

+42
-9
lines changed

1 file changed

+42
-9
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,17 +210,15 @@ function is_parameter_solvable(p, pmap, defs, guesses)
210210
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
211211
end
212212

213-
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
213+
function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp)
214214
if u0 === missing && p === missing
215-
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
216-
odefn.initializeprobpmap
215+
return SciMLBase.OverrideInitData(odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap)
217216
end
218217
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
219218
oldinitprob = odefn.initializeprob
220219
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
221220
!(oldinitprob.f.sys isa NonlinearSystem)
222-
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
223-
odefn.initializeprobpmap
221+
return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap)
224222
end
225223
pidxs = ParameterIndex[]
226224
pvals = []
@@ -262,14 +260,16 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
262260
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
263261
end
264262
initprob = remake(oldinitprob; u0 = newu0, p = newp)
265-
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
266-
odefn.initializeprobpmap
263+
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap)
267264
end
268265
dvs = unknowns(sys)
269266
ps = parameters(sys)
270267
u0map = to_varmap(u0, dvs)
268+
symbols_to_symbolics!(sys, u0map)
271269
pmap = to_varmap(p, ps)
270+
symbols_to_symbolics!(sys, pmap)
272271
guesses = Dict()
272+
defs = defaults(sys)
273273
if SciMLBase.has_initializeprob(odefn)
274274
oldsys = odefn.initializeprob.f.sys
275275
meta = get_metadata(oldsys)
@@ -278,6 +278,35 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
278278
pmap = merge(meta.pmap, pmap)
279279
merge!(guesses, meta.additional_guesses)
280280
end
281+
else
282+
# there is no initializeprob, so the original problem construction
283+
# had no solvable parameters and had the differential variables
284+
# specified in `u0map`.
285+
if u0 === missing
286+
# the user didn't pass `u0` to `remake`, so they want to retain
287+
# existing values. Fill the differential variables in `u0map`,
288+
# initialization will either be elided or solve for the algebraic
289+
# variables
290+
diff_idxs = isdiffeq.(equations(sys))
291+
for i in eachindex(dvs)
292+
diff_idxs[i] || continue
293+
u0map[dvs[i]] = newu0[i]
294+
end
295+
end
296+
if p === missing
297+
# the user didn't pass `p` to `remake`, so they want to retain
298+
# existing values. Fill all parameters in `pmap` so that none of
299+
# them are solvable.
300+
for p in ps
301+
pmap[p] = getp(sys, p)(newp)
302+
end
303+
end
304+
# all non-solvable parameters need values regardless
305+
for p in ps
306+
haskey(pmap, p) && continue
307+
is_parameter_solvable(p, pmap, defs, guesses) && continue
308+
pmap[p] = getp(sys, p)(newp)
309+
end
281310
end
282311
if t0 === nothing
283312
t0 = 0.0
@@ -286,8 +315,12 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
286315
filter_missing_values!(pmap)
287316
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
288317
kws = f.kwargs
289-
return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing),
290-
get(kws, :initializeprobpmap, nothing)
318+
initprob = get(kws, :initializeprob, nothing)
319+
if initprob === nothing
320+
return nothing
321+
end
322+
return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing),
323+
get(kws, :initializeprobpmap, nothing))
291324
end
292325

293326
"""

0 commit comments

Comments
 (0)