Skip to content

Commit 8d15812

Browse files
chore: add extra indirection to pick correct nlsol
1 parent b256da4 commit 8d15812

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

src/initialization.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
248248
return u0, p, true
249249
end
250250

251-
initdata::OverrideInitData = ChainRulesCore.@ignore_derivatives f.initialization_data
251+
initdata::OverrideInitData = f.initialization_data
252252
initprob = initdata.initializeprob
253253

254254
if initdata.update_initializeprob! !== nothing
@@ -260,7 +260,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
260260
end
261261

262262
if is_trivial_initialization(initdata)
263-
nlsol = initprob
263+
nlsol = initdata
264264
success = true
265265
else
266266
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
@@ -294,11 +294,11 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
294294
end
295295
end
296296

297-
u0 = if initdata.initializeprobmap !== nothing
298-
initdata.initializeprobmap(nlsol)
297+
if initdata.initializeprobmap !== nothing
298+
u0 = initdata.initializeprobmap(choose_branch(nlsol))
299299
end
300-
p = if initdata.initializeprobpmap !== nothing
301-
initdata.initializeprobpmap(valp, nlsol)
300+
if initdata.initializeprobpmap !== nothing
301+
p = initdata.initializeprobpmap(valp, choose_branch(nlsol))
302302
end
303303

304304
return u0, p, success

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,6 @@ Strips a SciMLSolution object and its interpolation of their functions to better
550550
function strip_solution(sol::AbstractSciMLSolution)
551551
sol
552552
end
553+
554+
choose_branch(x::OverrideInitData) = x.initializeprob
555+
choose_branch(sol::AbstractSciMLSolution) = sol

0 commit comments

Comments
 (0)