Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
struct TrackerOriginator <: ADOriginator end

include("utils.jl")
include("initialization.jl")
include("function_wrappers.jl")
include("scimlfunctions.jl")
include("alg_traits.jl")
Expand Down
32 changes: 32 additions & 0 deletions src/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
$(TYPEDEF)

A collection of all the data required for `OverrideInit`.
"""
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
The `AbstractNonlinearProblem` to solve for initialization.
"""
initializeprob::IProb
"""
A function which takes `(initializeprob, prob)` and updates
the parameters of the former with their values in the latter.
"""
update_initializeprob!::UIProb
"""
A function which takes the solution of `initializeprob` and returns
the state vector of the original problem.
"""
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
the parameter object of the original problem.
"""
initializeprobpmap::IProbPmap

function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
initprobpmap::L) where {I, J, K, L}
@assert initprob isa Union{NonlinearProblem, NonlinearLeastSquaresProblem}
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
end
end
56 changes: 25 additions & 31 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ function remake(prob::ODEProblem; f = missing,

if f === missing
if build_initializeprob
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob(
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, tspan[1], p)
else
initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing
initialization_data = nothing
end
if specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
Expand All @@ -137,45 +137,21 @@ function remake(prob::ODEProblem; f = missing,
wrapfun_iip(
unwrapped_f(prob.f.f),
(newu0, newu0, newp,
ptspan[1]));
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
ptspan[1])); initialization_data)
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_oop(
unwrapped_f(prob.f.f),
(newu0, newp,
ptspan[1]));
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
ptspan[1])); initialization_data)
end
else
_f = prob.f
if __has_initializeprob(_f)
if __has_initialization_data(_f)
props = getproperties(_f)
@reset props.initializeprob = initializeprob
@reset props.initialization_data = initialization_data
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
if __has_update_initializeprob!(_f)
props = getproperties(_f)
@reset props.update_initializeprob! = update_initializeprob!
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
if __has_initializeprobmap(_f)
props = getproperties(_f)
@reset props.initializeprobmap = initializeprobmap
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
if __has_initializeprobpmap(_f)
props = getproperties(_f)
@reset props.initializeprobpmap = initializeprobpmap
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
_f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...)
end
end
elseif f isa AbstractODEFunction
Expand Down Expand Up @@ -206,6 +182,9 @@ end
"""
remake_initializeprob(sys, scimlfn, u0, t0, p)

!! WARN
This method is deprecated. Please see `remake_initialization_data`

Re-create the initialization problem present in the function `scimlfn`, using the
associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
Expand All @@ -223,6 +202,21 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p)
scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
end

"""
remake_initialization_data(sys, scimlfn, u0, t0, p)

Re-create the initialization data present in the function `scimlfn`, using the
associated system `sys` and the user provided new values of `u0`, initial time `t0` and
`p`. By default, this calls `remake_initializeprob` for backward compatibility and
attempts to construct an `OverrideInitData` from the result.

Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
"""
function remake_initialization_data(sys, scimlfn, u0, t0, p)
return reconstruct_initialization_data(
nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...)
end

"""
remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
Expand Down
Loading
Loading