Skip to content

Commit f91d674

Browse files
Merge pull request #1143 from AayushSabharwal/as/pre-solve-hook
fix: handle `u0` and `p` in `kwargs` of `get_concrete_problem`
2 parents d70c406 + fa17edc commit f91d674

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

src/solve.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,13 @@ an updated `prob` to be used for solving. All implementations should accept arbi
523523
keyword arguments.
524524
525525
Should be called before the problem is solved, after performing type-promotion on the
526-
problem.
526+
problem. If the returned problem is not `===` the provided `prob`, it is assumed to
527+
contain the `u0` and `p` passed as keyword arguments.
528+
529+
# Keyword Arguments
530+
531+
- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which
532+
should be used instead of the ones in `prob`.
527533
"""
528534
function get_updated_symbolic_problem(indp, prob; kw...)
529535
return prob
@@ -1239,27 +1245,36 @@ function checkkwargs(kwargshandle; kwargs...)
12391245
end
12401246

12411247
function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...)
1242-
get_updated_symbolic_problem(_get_root_indp(prob), prob)
1248+
get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
12431249
end
12441250

12451251
function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...)
1246-
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
1252+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
1253+
if prob !== prob
1254+
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
1255+
end
12471256
p = get_concrete_p(prob, kwargs)
12481257
u0 = get_concrete_u0(prob, isadapt, Inf, kwargs)
12491258
u0 = promote_u0(u0, p, nothing)
12501259
remake(prob; u0 = u0, p = p)
12511260
end
12521261

12531262
function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...)
1254-
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
1263+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
1264+
if prob !== prob
1265+
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
1266+
end
12551267
p = get_concrete_p(prob, kwargs)
12561268
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
12571269
u0 = promote_u0(u0, p, nothing)
12581270
remake(prob; u0 = u0, p = p)
12591271
end
12601272

12611273
function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...)
1262-
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
1274+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
1275+
if prob !== prob
1276+
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
1277+
end
12631278
p = get_concrete_p(prob, kwargs)
12641279
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
12651280
u0 = promote_u0(u0, p, nothing)
@@ -1281,7 +1296,10 @@ function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...;
12811296
end
12821297

12831298
function get_concrete_problem(prob, isadapt; kwargs...)
1284-
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
1299+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
1300+
if prob !== prob
1301+
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
1302+
end
12851303
p = get_concrete_p(prob, kwargs)
12861304
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
12871305
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
@@ -1300,7 +1318,10 @@ function get_concrete_problem(prob, isadapt; kwargs...)
13001318
end
13011319

13021320
function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...)
1303-
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
1321+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
1322+
if prob !== prob
1323+
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
1324+
end
13041325
p = get_concrete_p(prob, kwargs)
13051326
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
13061327
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
@@ -1324,7 +1345,10 @@ function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...)
13241345
end
13251346

13261347
function get_concrete_problem(prob::DDEProblem, isadapt; kwargs...)
1327-
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
1348+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)
1349+
if prob !== prob
1350+
kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
1351+
end
13281352
p = get_concrete_p(prob, kwargs)
13291353
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
13301354
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)

0 commit comments

Comments
 (0)