Skip to content

Commit f69ad5a

Browse files
feat: add update_initializeprob! to relevant SciMLFunctions, remake
1 parent bad69a2 commit f69ad5a

File tree

2 files changed

+64
-33
lines changed

2 files changed

+64
-33
lines changed

src/remake.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function remake(prob::ODEProblem; f = missing,
123123
iip = isinplace(prob)
124124

125125
if f === missing
126-
initializeprob, initializeprobmap, initializeprobpmap = remake_initializeprob(
126+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob(
127127
prob.f.sys, prob.f, u0, tspan[1], p)
128128
if specialization(prob.f) === FunctionWrapperSpecialize
129129
ptspan = promote_tspan(tspan)
@@ -133,14 +133,14 @@ function remake(prob::ODEProblem; f = missing,
133133
unwrapped_f(prob.f.f),
134134
(newu0, newu0, newp,
135135
ptspan[1]));
136-
initializeprob, initializeprobmap, initializeprobpmap)
136+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
137137
else
138138
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
139139
wrapfun_oop(
140140
unwrapped_f(prob.f.f),
141141
(newu0, newp,
142142
ptspan[1]));
143-
initializeprob, initializeprobmap, initializeprobpmap)
143+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
144144
end
145145
else
146146
_f = prob.f
@@ -151,6 +151,13 @@ function remake(prob::ODEProblem; f = missing,
151151
_f = parameterless_type(_f){
152152
iip, specialization(_f), map(typeof, props)...}(props...)
153153
end
154+
if __has_update_initializeprob!(_f)
155+
props = getproperties(_f)
156+
@reset props.update_initializeprob! = update_initializeprob!
157+
props = values(props)
158+
_f = parameterless_type(_f){
159+
iip, specialization(_f), map(typeof, props)...}(props...)
160+
end
154161
if __has_initializeprobmap(_f)
155162
props = getproperties(_f)
156163
@reset props.initializeprobmap = initializeprobmap
@@ -196,18 +203,19 @@ end
196203
197204
Re-create the initialization problem present in the function `scimlfn`, using the
198205
associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
199-
`p`. By default, returns `nothing, nothing, nothing` if `scimlfn` does not have an
206+
`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
200207
initialization problem, and
201-
`scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap` if it
202-
does.
208+
`scimlfn.initializeprob, scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap`
209+
if it does.
203210
204211
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
205212
"""
206213
function remake_initializeprob(sys, scimlfn, u0, t0, p)
207214
if !has_initializeprob(scimlfn)
208-
return nothing, nothing, nothing
215+
return nothing, nothing, nothing, nothing
209216
end
210-
return scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
217+
return scimlfn.initializeprob,
218+
scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
211219
end
212220

213221
"""
@@ -703,7 +711,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
703711
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
704712
end
705713

706-
function updated_u0_p(prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
714+
function updated_u0_p(
715+
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
707716
if u0 === missing && p === missing
708717
return state_values(prob), parameter_values(prob)
709718
end

src/scimlfunctions.jl

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
405+
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -420,6 +420,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
420420
colorvec::TCV
421421
sys::SYS
422422
initializeprob::IProb
423+
update_initializeprob!::UIProb
423424
initializeprobmap::IProbMap
424425
initializeprobpmap::IProbPmap
425426
end
@@ -519,7 +520,7 @@ information on generating the SplitFunction from this symbolic engine.
519520
struct SplitFunction{
520521
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
521522
TPJ, O,
522-
TCV, SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
523+
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
523524
f1::F1
524525
f2::F2
525526
mass_matrix::TMM
@@ -538,6 +539,7 @@ struct SplitFunction{
538539
colorvec::TCV
539540
sys::SYS
540541
initializeprob::IProb
542+
update_initializeprob!::UIProb
541543
initializeprobmap::IProbMap
542544
initializeprobpmap::IProbPmap
543545
end
@@ -1508,7 +1510,7 @@ automatically symbolically generating the Jacobian and more from the
15081510
numerically-defined functions.
15091511
"""
15101512
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1511-
SYS, IProb, IProbMap, IProbPmap} <:
1513+
SYS, IProb, UIProb, IProbMap, IProbPmap} <:
15121514
AbstractDAEFunction{iip}
15131515
f::F
15141516
analytic::Ta
@@ -1525,6 +1527,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15251527
colorvec::TCV
15261528
sys::SYS
15271529
initializeprob::IProb
1530+
update_initializeprob!::UIProb
15281531
initializeprobmap::IProbMap
15291532
initializeprobpmap::IProbPmap
15301533
end
@@ -2413,6 +2416,8 @@ function ODEFunction{iip, specialize}(f;
24132416
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
24142417
sys = __has_sys(f) ? f.sys : nothing,
24152418
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
2419+
update_initializeprob! = __has_update_initializeprob!(f) ?
2420+
f.update_initializeprob! : nothing,
24162421
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
24172422
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
24182423
) where {iip,
@@ -2472,10 +2477,10 @@ function ODEFunction{iip, specialize}(f;
24722477
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24732478
Any,
24742479
typeof(_colorvec),
2475-
typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2480+
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24762481
jvp, vjp, jac_prototype, sparsity, Wfact,
24772482
Wfact_t, W_prototype, paramjac,
2478-
observed, _colorvec, sys, initializeprob, initializeprobmap,
2483+
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
24792484
initializeprobpmap)
24802485
elseif specialize === false
24812486
ODEFunction{iip, FunctionWrapperSpecialize,
@@ -2485,12 +2490,12 @@ function ODEFunction{iip, specialize}(f;
24852490
typeof(paramjac),
24862491
typeof(observed),
24872492
typeof(_colorvec),
2488-
typeof(sys), typeof(initializeprob),
2493+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
24892494
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
24902495
analytic, tgrad, jac,
24912496
jvp, vjp, jac_prototype, sparsity, Wfact,
24922497
Wfact_t, W_prototype, paramjac,
2493-
observed, _colorvec, sys, initializeprob, initializeprobmap,
2498+
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
24942499
initializeprobpmap)
24952500
else
24962501
ODEFunction{iip, specialize,
@@ -2500,11 +2505,12 @@ function ODEFunction{iip, specialize}(f;
25002505
typeof(paramjac),
25012506
typeof(observed),
25022507
typeof(_colorvec),
2503-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2508+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2509+
typeof(initializeprobmap),
25042510
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
25052511
jvp, vjp, jac_prototype, sparsity, Wfact,
25062512
Wfact_t, W_prototype, paramjac,
2507-
observed, _colorvec, sys, initializeprob, initializeprobmap,
2513+
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
25082514
initializeprobpmap)
25092515
end
25102516
end
@@ -2522,10 +2528,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25222528
Any, Any, Any, Any, typeof(f.jac_prototype),
25232529
typeof(f.sparsity), Any, Any, Any,
25242530
Any, typeof(f.colorvec),
2525-
typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2531+
typeof(f.sys), Any, Any, Any, Any}(
2532+
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25262533
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25272534
f.Wfact_t, f.W_prototype, f.paramjac,
2528-
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
2535+
f.observed, f.colorvec, f.sys, f.initializeprob,
2536+
f.update_initializeprob!, f.initializeprobmap,
25292537
f.initializeprobpmap)
25302538
else
25312539
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
@@ -2534,11 +2542,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25342542
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25352543
typeof(f.paramjac),
25362544
typeof(f.observed), typeof(f.colorvec),
2537-
typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap),
2545+
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
2546+
typeof(f.initializeprobmap),
25382547
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25392548
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25402549
f.Wfact_t, f.W_prototype, f.paramjac,
2541-
f.observed, f.colorvec, f.sys, f.initializeprob,
2550+
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
25422551
f.initializeprobmap, f.initializeprobpmap)
25432552
end
25442553
end
@@ -2641,7 +2650,8 @@ end
26412650

26422651
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
26432652
vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
2644-
observed, colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
2653+
observed, colorvec, sys, initializeprob, update_initializeprob!,
2654+
initializeprobmap, initializeprobpmap)
26452655
f1 = ODEFunction(f1)
26462656
f2 = ODEFunction(f2)
26472657

@@ -2655,12 +2665,12 @@ end
26552665
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
26562666
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
26572667
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2658-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2668+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
26592669
typeof(initializeprobpmap)}(
26602670
f1, f2, mass_matrix,
26612671
cache, analytic, tgrad, jac, jvp, vjp,
26622672
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2663-
initializeprob, initializeprobmap, initializeprobpmap)
2673+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
26642674
end
26652675
function SplitFunction{iip, specialize}(f1, f2;
26662676
mass_matrix = __has_mass_matrix(f1) ?
@@ -2690,6 +2700,8 @@ function SplitFunction{iip, specialize}(f1, f2;
26902700
nothing,
26912701
sys = __has_sys(f1) ? f1.sys : nothing,
26922702
initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing,
2703+
update_initializeprob! = __has_update_initializeprob!(f1) ?
2704+
f1.update_initializeprob! : nothing,
26932705
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
26942706
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
26952707
) where {iip,
@@ -2701,12 +2713,12 @@ function SplitFunction{iip, specialize}(f1, f2;
27012713

27022714
if specialize === NoSpecialize
27032715
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
2704-
Any, Any, Any, Any, Any,
2716+
Any, Any, Any, Any, Any, Any,
27052717
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27062718
analytic,
27072719
tgrad, jac, jvp, vjp, jac_prototype,
27082720
sparsity, Wfact, Wfact_t, paramjac,
2709-
observed, colorvec, sys, initializeprob, initializeprobmap,
2721+
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
27102722
initializeprobpmap, initializeprobpmap)
27112723
else
27122724
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
@@ -2715,12 +2727,13 @@ function SplitFunction{iip, specialize}(f1, f2;
27152727
typeof(jac_prototype), typeof(sparsity),
27162728
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27172729
typeof(colorvec),
2718-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2730+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2731+
typeof(initializeprobmap),
27192732
typeof(initializeprobpmap)}(f1, f2,
27202733
mass_matrix, _func_cache, analytic, tgrad, jac,
27212734
jvp, vjp, jac_prototype,
27222735
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2723-
initializeprob, initializeprobmap, initializeprobpmap)
2736+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
27242737
end
27252738
end
27262739

@@ -3346,6 +3359,8 @@ function DAEFunction{iip, specialize}(f;
33463359
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
33473360
sys = __has_sys(f) ? f.sys : nothing,
33483361
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
3362+
update_initializeprob! = __has_update_initializeprob!(f) ?
3363+
f.update_initializeprob! : nothing,
33493364
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
33503365
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
33513366
iip,
@@ -3387,22 +3402,25 @@ function DAEFunction{iip, specialize}(f;
33873402
DAEFunction{iip, specialize, Any, Any, Any,
33883403
Any, Any, Any, Any, Any,
33893404
Any, Any, Any,
3390-
Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
3405+
Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
33913406
vjp, jac_prototype, sparsity,
33923407
Wfact, Wfact_t, paramjac, observed,
3393-
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
3408+
_colorvec, sys, initializeprob, update_initializeprob!,
3409+
initializeprobmap, initializeprobpmap)
33943410
else
33953411
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
33963412
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
33973413
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
33983414
typeof(paramjac),
33993415
typeof(observed), typeof(_colorvec),
3400-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
3416+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
3417+
typeof(initializeprobmap),
34013418
typeof(initializeprobpmap)}(
34023419
_f, analytic, tgrad, jac, jvp, vjp,
34033420
jac_prototype, sparsity, Wfact, Wfact_t,
34043421
paramjac, observed,
3405-
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
3422+
_colorvec, sys, initializeprob, update_initializeprob!,
3423+
initializeprobmap, initializeprobpmap)
34063424
end
34073425
end
34083426

@@ -4345,6 +4363,7 @@ __has_sys(f) = isdefined(f, :sys)
43454363
__has_analytic_full(f) = isdefined(f, :analytic_full)
43464364
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
43474365
__has_initializeprob(f) = isdefined(f, :initializeprob)
4366+
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
43484367
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
43494368
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
43504369

@@ -4362,6 +4381,9 @@ has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
43624381
function has_initializeprob(f::AbstractSciMLFunction)
43634382
__has_initializeprob(f) && f.initializeprob !== nothing
43644383
end
4384+
function has_update_initializeprob!(f::AbstractSciMLFunction)
4385+
__has_update_initializeprob!(f) && f.update_initializeprob! !== nothing
4386+
end
43654387
function has_initializeprobmap(f::AbstractSciMLFunction)
43664388
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
43674389
end

0 commit comments

Comments
 (0)