Skip to content

Commit 932cb31

Browse files
feat: add update_initializeprob! to relevant SciMLFunctions, remake
1 parent 11f50a7 commit 932cb31

File tree

2 files changed

+63
-32
lines changed

2 files changed

+63
-32
lines changed

src/remake.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function remake(prob::ODEProblem; f = missing,
123123
iip = isinplace(prob)
124124

125125
if f === missing
126-
initializeprob, initializeprobmap, initializeprobpmap = remake_initializeprob(
126+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob(
127127
prob.f.sys, prob.f, u0, tspan[1], p)
128128
if specialization(prob.f) === FunctionWrapperSpecialize
129129
ptspan = promote_tspan(tspan)
@@ -133,14 +133,14 @@ function remake(prob::ODEProblem; f = missing,
133133
unwrapped_f(prob.f.f),
134134
(newu0, newu0, newp,
135135
ptspan[1]));
136-
initializeprob, initializeprobmap, initializeprobpmap)
136+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
137137
else
138138
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
139139
wrapfun_oop(
140140
unwrapped_f(prob.f.f),
141141
(newu0, newp,
142142
ptspan[1]));
143-
initializeprob, initializeprobmap, initializeprobpmap)
143+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
144144
end
145145
else
146146
_f = prob.f
@@ -151,6 +151,13 @@ function remake(prob::ODEProblem; f = missing,
151151
_f = parameterless_type(_f){
152152
iip, specialization(_f), map(typeof, props)...}(props...)
153153
end
154+
if __has_update_initializeprob!(_f)
155+
props = getproperties(_f)
156+
@reset props.update_initializeprob! = update_initializeprob!
157+
props = values(props)
158+
_f = parameterless_type(_f){
159+
iip, specialization(_f), map(typeof, props)...}(props...)
160+
end
154161
if __has_initializeprobmap(_f)
155162
props = getproperties(_f)
156163
@reset props.initializeprobmap = initializeprobmap
@@ -196,18 +203,19 @@ end
196203
197204
Re-create the initialization problem present in the function `scimlfn`, using the
198205
associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
199-
`p`. By default, returns `nothing, nothing, nothing` if `scimlfn` does not have an
206+
`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
200207
initialization problem, and
201-
`scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap` if it
202-
does.
208+
`scimlfn.initializeprob, scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap`
209+
if it does.
203210
204211
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
205212
"""
206213
function remake_initializeprob(sys, scimlfn, u0, t0, p)
207214
if !has_initializeprob(scimlfn)
208215
return nothing, nothing, nothing
209216
end
210-
return scimlfn.initializeprob, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
217+
return scimlfn.initializeprob,
218+
scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
211219
end
212220

213221
"""
@@ -689,7 +697,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
689697
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
690698
end
691699

692-
function updated_u0_p(prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
700+
function updated_u0_p(
701+
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
693702
if u0 === missing && p === missing
694703
return state_values(prob), parameter_values(prob)
695704
end

src/scimlfunctions.jl

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
405+
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -420,6 +420,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
420420
colorvec::TCV
421421
sys::SYS
422422
initializeprob::IProb
423+
update_initializeprob!::UIProb
423424
initializeprobmap::IProbMap
424425
initializeprobpmap::IProbPmap
425426
end
@@ -519,7 +520,7 @@ information on generating the SplitFunction from this symbolic engine.
519520
struct SplitFunction{
520521
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
521522
TPJ, O,
522-
TCV, SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
523+
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
523524
f1::F1
524525
f2::F2
525526
mass_matrix::TMM
@@ -538,6 +539,7 @@ struct SplitFunction{
538539
colorvec::TCV
539540
sys::SYS
540541
initializeprob::IProb
542+
update_initializeprob!::UIProb
541543
initializeprobmap::IProbMap
542544
initializeprobpmap::IProbPmap
543545
end
@@ -1508,7 +1510,7 @@ automatically symbolically generating the Jacobian and more from the
15081510
numerically-defined functions.
15091511
"""
15101512
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1511-
SYS, IProb, IProbMap, IProbPmap} <:
1513+
SYS, IProb, UIProb, IProbMap, IProbPmap} <:
15121514
AbstractDAEFunction{iip}
15131515
f::F
15141516
analytic::Ta
@@ -1525,6 +1527,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15251527
colorvec::TCV
15261528
sys::SYS
15271529
initializeprob::IProb
1530+
update_initializeprob!::UIProb
15281531
initializeprobmap::IProbMap
15291532
initializeprobpmap::IProbPmap
15301533
end
@@ -2418,6 +2421,8 @@ function ODEFunction{iip, specialize}(f;
24182421
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
24192422
sys = __has_sys(f) ? f.sys : nothing,
24202423
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
2424+
update_initializeprob! = __has_update_initializeprob!(f) ?
2425+
f.update_initializeprob! : nothing,
24212426
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
24222427
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
24232428
) where {iip,
@@ -2477,10 +2482,10 @@ function ODEFunction{iip, specialize}(f;
24772482
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24782483
Any,
24792484
typeof(_colorvec),
2480-
typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2485+
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24812486
jvp, vjp, jac_prototype, sparsity, Wfact,
24822487
Wfact_t, W_prototype, paramjac,
2483-
observed, _colorvec, sys, initializeprob, initializeprobmap,
2488+
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
24842489
initializeprobpmap)
24852490
elseif specialize === false
24862491
ODEFunction{iip, FunctionWrapperSpecialize,
@@ -2490,12 +2495,12 @@ function ODEFunction{iip, specialize}(f;
24902495
typeof(paramjac),
24912496
typeof(observed),
24922497
typeof(_colorvec),
2493-
typeof(sys), typeof(initializeprob),
2498+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
24942499
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
24952500
analytic, tgrad, jac,
24962501
jvp, vjp, jac_prototype, sparsity, Wfact,
24972502
Wfact_t, W_prototype, paramjac,
2498-
observed, _colorvec, sys, initializeprob, initializeprobmap,
2503+
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
24992504
initializeprobpmap)
25002505
else
25012506
ODEFunction{iip, specialize,
@@ -2505,11 +2510,12 @@ function ODEFunction{iip, specialize}(f;
25052510
typeof(paramjac),
25062511
typeof(observed),
25072512
typeof(_colorvec),
2508-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2513+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2514+
typeof(initializeprobmap),
25092515
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
25102516
jvp, vjp, jac_prototype, sparsity, Wfact,
25112517
Wfact_t, W_prototype, paramjac,
2512-
observed, _colorvec, sys, initializeprob, initializeprobmap,
2518+
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
25132519
initializeprobpmap)
25142520
end
25152521
end
@@ -2527,10 +2533,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25272533
Any, Any, Any, Any, typeof(f.jac_prototype),
25282534
typeof(f.sparsity), Any, Any, Any,
25292535
Any, typeof(f.colorvec),
2530-
typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2536+
typeof(f.sys), Any, Any, Any, Any}(
2537+
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25312538
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25322539
f.Wfact_t, f.W_prototype, f.paramjac,
2533-
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
2540+
f.observed, f.colorvec, f.sys, f.initializeprob,
2541+
f.update_initializeprob!, f.initializeprobmap,
25342542
f.initializeprobpmap)
25352543
else
25362544
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
@@ -2539,11 +2547,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25392547
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25402548
typeof(f.paramjac),
25412549
typeof(f.observed), typeof(f.colorvec),
2542-
typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap),
2550+
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
2551+
typeof(f.initializeprobmap),
25432552
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25442553
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25452554
f.Wfact_t, f.W_prototype, f.paramjac,
2546-
f.observed, f.colorvec, f.sys, f.initializeprob,
2555+
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
25472556
f.initializeprobmap, f.initializeprobpmap)
25482557
end
25492558
end
@@ -2646,7 +2655,8 @@ end
26462655

26472656
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
26482657
vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
2649-
observed, colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
2658+
observed, colorvec, sys, initializeprob, update_initializeprob!,
2659+
initializeprobmap, initializeprobpmap)
26502660
f1 = ODEFunction(f1)
26512661
f2 = ODEFunction(f2)
26522662

@@ -2660,12 +2670,12 @@ end
26602670
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
26612671
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
26622672
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2663-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2673+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
26642674
typeof(initializeprobpmap)}(
26652675
f1, f2, mass_matrix,
26662676
cache, analytic, tgrad, jac, jvp, vjp,
26672677
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2668-
initializeprob, initializeprobmap, initializeprobpmap)
2678+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
26692679
end
26702680
function SplitFunction{iip, specialize}(f1, f2;
26712681
mass_matrix = __has_mass_matrix(f1) ?
@@ -2695,6 +2705,8 @@ function SplitFunction{iip, specialize}(f1, f2;
26952705
nothing,
26962706
sys = __has_sys(f1) ? f1.sys : nothing,
26972707
initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing,
2708+
update_initializeprob! = __has_update_initializeprob!(f1) ?
2709+
f1.update_initializeprob! : nothing,
26982710
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
26992711
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
27002712
) where {iip,
@@ -2706,12 +2718,12 @@ function SplitFunction{iip, specialize}(f1, f2;
27062718

27072719
if specialize === NoSpecialize
27082720
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
2709-
Any, Any, Any, Any, Any,
2721+
Any, Any, Any, Any, Any, Any,
27102722
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27112723
analytic,
27122724
tgrad, jac, jvp, vjp, jac_prototype,
27132725
sparsity, Wfact, Wfact_t, paramjac,
2714-
observed, colorvec, sys, initializeprob, initializeprobmap,
2726+
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
27152727
initializeprobpmap, initializeprobpmap)
27162728
else
27172729
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
@@ -2720,12 +2732,13 @@ function SplitFunction{iip, specialize}(f1, f2;
27202732
typeof(jac_prototype), typeof(sparsity),
27212733
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27222734
typeof(colorvec),
2723-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2735+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2736+
typeof(initializeprobmap),
27242737
typeof(initializeprobpmap)}(f1, f2,
27252738
mass_matrix, _func_cache, analytic, tgrad, jac,
27262739
jvp, vjp, jac_prototype,
27272740
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2728-
initializeprob, initializeprobmap, initializeprobpmap)
2741+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
27292742
end
27302743
end
27312744

@@ -3351,6 +3364,8 @@ function DAEFunction{iip, specialize}(f;
33513364
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
33523365
sys = __has_sys(f) ? f.sys : nothing,
33533366
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
3367+
update_initializeprob! = __has_update_initializeprob!(f) ?
3368+
f.update_initializeprob! : nothing,
33543369
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
33553370
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
33563371
iip,
@@ -3392,22 +3407,25 @@ function DAEFunction{iip, specialize}(f;
33923407
DAEFunction{iip, specialize, Any, Any, Any,
33933408
Any, Any, Any, Any, Any,
33943409
Any, Any, Any,
3395-
Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
3410+
Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
33963411
vjp, jac_prototype, sparsity,
33973412
Wfact, Wfact_t, paramjac, observed,
3398-
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
3413+
_colorvec, sys, initializeprob, update_initializeprob!,
3414+
initializeprobmap, initializeprobpmap)
33993415
else
34003416
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
34013417
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
34023418
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
34033419
typeof(paramjac),
34043420
typeof(observed), typeof(_colorvec),
3405-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
3421+
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
3422+
typeof(initializeprobmap),
34063423
typeof(initializeprobpmap)}(
34073424
_f, analytic, tgrad, jac, jvp, vjp,
34083425
jac_prototype, sparsity, Wfact, Wfact_t,
34093426
paramjac, observed,
3410-
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
3427+
_colorvec, sys, initializeprob, update_initializeprob!,
3428+
initializeprobmap, initializeprobpmap)
34113429
end
34123430
end
34133431

@@ -4350,6 +4368,7 @@ __has_sys(f) = isdefined(f, :sys)
43504368
__has_analytic_full(f) = isdefined(f, :analytic_full)
43514369
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
43524370
__has_initializeprob(f) = isdefined(f, :initializeprob)
4371+
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
43534372
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
43544373
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
43554374

@@ -4367,6 +4386,9 @@ has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
43674386
function has_initializeprob(f::AbstractSciMLFunction)
43684387
__has_initializeprob(f) && f.initializeprob !== nothing
43694388
end
4389+
function has_update_initializeprob!(f::AbstractSciMLFunction)
4390+
__has_update_initializeprob!(f) && f.update_initializeprob! !== nothing
4391+
end
43704392
function has_initializeprobmap(f::AbstractSciMLFunction)
43714393
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
43724394
end

0 commit comments

Comments
 (0)