Skip to content

Commit 77fff1e

Browse files
Merge pull request #838 from AayushSabharwal/as/initdata
refactor: use `initialization_data` instead of `initializeprob`, etc.
2 parents babd239 + 55c171d commit 77fff1e

File tree

6 files changed

+166
-126
lines changed

6 files changed

+166
-126
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,7 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
654654
struct TrackerOriginator <: ADOriginator end
655655

656656
include("utils.jl")
657+
include("initialization.jl")
657658
include("function_wrappers.jl")
658659
include("scimlfunctions.jl")
659660
include("alg_traits.jl")

src/initialization.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
A collection of all the data required for `OverrideInit`.
5+
"""
6+
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
7+
"""
8+
The `AbstractNonlinearProblem` to solve for initialization.
9+
"""
10+
initializeprob::IProb
11+
"""
12+
A function which takes `(initializeprob, prob)` and updates
13+
the parameters of the former with their values in the latter.
14+
"""
15+
update_initializeprob!::UIProb
16+
"""
17+
A function which takes the solution of `initializeprob` and returns
18+
the state vector of the original problem.
19+
"""
20+
initializeprobmap::IProbMap
21+
"""
22+
A function which takes the solution of `initializeprob` and returns
23+
the parameter object of the original problem.
24+
"""
25+
initializeprobpmap::IProbPmap
26+
27+
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
28+
initprobpmap::L) where {I, J, K, L}
29+
@assert initprob isa Union{NonlinearProblem, NonlinearLeastSquaresProblem}
30+
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
31+
end
32+
end

src/remake.jl

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ function remake(prob::ODEProblem; f = missing,
125125

126126
if f === missing
127127
if build_initializeprob
128-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob(
128+
initialization_data = remake_initialization_data(
129129
prob.f.sys, prob.f, u0, tspan[1], p)
130130
else
131-
initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing
131+
initialization_data = nothing
132132
end
133133
if specialization(prob.f) === FunctionWrapperSpecialize
134134
ptspan = promote_tspan(tspan)
@@ -137,45 +137,21 @@ function remake(prob::ODEProblem; f = missing,
137137
wrapfun_iip(
138138
unwrapped_f(prob.f.f),
139139
(newu0, newu0, newp,
140-
ptspan[1]));
141-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
140+
ptspan[1])); initialization_data)
142141
else
143142
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
144143
wrapfun_oop(
145144
unwrapped_f(prob.f.f),
146145
(newu0, newp,
147-
ptspan[1]));
148-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
146+
ptspan[1])); initialization_data)
149147
end
150148
else
151149
_f = prob.f
152-
if __has_initializeprob(_f)
150+
if __has_initialization_data(_f)
153151
props = getproperties(_f)
154-
@reset props.initializeprob = initializeprob
152+
@reset props.initialization_data = initialization_data
155153
props = values(props)
156-
_f = parameterless_type(_f){
157-
iip, specialization(_f), map(typeof, props)...}(props...)
158-
end
159-
if __has_update_initializeprob!(_f)
160-
props = getproperties(_f)
161-
@reset props.update_initializeprob! = update_initializeprob!
162-
props = values(props)
163-
_f = parameterless_type(_f){
164-
iip, specialization(_f), map(typeof, props)...}(props...)
165-
end
166-
if __has_initializeprobmap(_f)
167-
props = getproperties(_f)
168-
@reset props.initializeprobmap = initializeprobmap
169-
props = values(props)
170-
_f = parameterless_type(_f){
171-
iip, specialization(_f), map(typeof, props)...}(props...)
172-
end
173-
if __has_initializeprobpmap(_f)
174-
props = getproperties(_f)
175-
@reset props.initializeprobpmap = initializeprobpmap
176-
props = values(props)
177-
_f = parameterless_type(_f){
178-
iip, specialization(_f), map(typeof, props)...}(props...)
154+
_f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...)
179155
end
180156
end
181157
elseif f isa AbstractODEFunction
@@ -206,6 +182,9 @@ end
206182
"""
207183
remake_initializeprob(sys, scimlfn, u0, t0, p)
208184
185+
!! WARN
186+
This method is deprecated. Please see `remake_initialization_data`
187+
209188
Re-create the initialization problem present in the function `scimlfn`, using the
210189
associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
211190
`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
@@ -223,6 +202,21 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p)
223202
scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
224203
end
225204

205+
"""
206+
remake_initialization_data(sys, scimlfn, u0, t0, p)
207+
208+
Re-create the initialization data present in the function `scimlfn`, using the
209+
associated system `sys` and the user provided new values of `u0`, initial time `t0` and
210+
`p`. By default, this calls `remake_initializeprob` for backward compatibility and
211+
attempts to construct an `OverrideInitData` from the result.
212+
213+
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
214+
"""
215+
function remake_initialization_data(sys, scimlfn, u0, t0, p)
216+
return reconstruct_initialization_data(
217+
nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...)
218+
end
219+
226220
"""
227221
remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing,
228222
p = missing, kwargs = missing, problem_type = missing, _kwargs...)

0 commit comments

Comments
 (0)