Skip to content

Commit 19d46e6

Browse files
refactor: use initialization_data instead of initializeprob, etc.
1 parent de94f4e commit 19d46e6

File tree

1 file changed

+85
-77
lines changed

1 file changed

+85
-77
lines changed

src/scimlfunctions.jl

Lines changed: 85 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the
405405
numerically-defined functions.
406406
"""
407407
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
408-
O, TCV, SYS,
409-
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
408+
O, TCV,
409+
SYS, ID, NLP} <: AbstractODEFunction{iip}
410410
f::F
411411
mass_matrix::TMM
412412
analytic::Ta
@@ -423,10 +423,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
423423
observed::O
424424
colorvec::TCV
425425
sys::SYS
426-
initializeprob::IProb
427-
update_initializeprob!::UIProb
428-
initializeprobmap::IProbMap
429-
initializeprobpmap::IProbPmap
426+
initialization_data::ID
430427
nlprob::NLP
431428
end
432429

@@ -530,8 +527,8 @@ information on generating the SplitFunction from this symbolic engine.
530527
"""
531528
struct SplitFunction{
532529
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
533-
TPJ, O, TCV, SYS,
534-
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
530+
TPJ, O,
531+
TCV, SYS, ID, NLP} <: AbstractODEFunction{iip}
535532
f1::F1
536533
f2::F2
537534
mass_matrix::TMM
@@ -550,10 +547,7 @@ struct SplitFunction{
550547
observed::O
551548
colorvec::TCV
552549
sys::SYS
553-
initializeprob::IProb
554-
update_initializeprob!::UIProb
555-
initializeprobmap::IProbMap
556-
initializeprobpmap::IProbPmap
550+
initialization_data::ID
557551
nlprob::NLP
558552
end
559553

@@ -1529,7 +1523,7 @@ automatically symbolically generating the Jacobian and more from the
15291523
numerically-defined functions.
15301524
"""
15311525
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1532-
SYS, IProb, UIProb, IProbMap, IProbPmap} <:
1526+
SYS, ID} <:
15331527
AbstractDAEFunction{iip}
15341528
f::F
15351529
analytic::Ta
@@ -1545,10 +1539,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15451539
observed::O
15461540
colorvec::TCV
15471541
sys::SYS
1548-
initializeprob::IProb
1549-
update_initializeprob!::UIProb
1550-
initializeprobmap::IProbMap
1551-
initializeprobpmap::IProbPmap
1542+
initialization_data::ID
15521543
end
15531544

15541545
"""
@@ -2440,6 +2431,8 @@ function ODEFunction{iip, specialize}(f;
24402431
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
24412432
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
24422433
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
2434+
initialization_data = __has_initialization_data(f) ? f.initialization_data :
2435+
nothing
24432436
) where {iip,
24442437
specialize
24452438
}
@@ -2486,8 +2479,11 @@ function ODEFunction{iip, specialize}(f;
24862479
_f = prepare_function(f)
24872480

24882481
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
2482+
initdata = reconstruct_initialization_data(
2483+
initialization_data, initializeprob, update_initializeprob!,
2484+
initializeprobmap, initializeprobpmap)
24892485

2490-
@assert typeof(initializeprob) <:
2486+
@assert typeof(initdata.initializeprob) <:
24912487
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
24922488

24932489
if specialize === NoSpecialize
@@ -2497,11 +2493,10 @@ function ODEFunction{iip, specialize}(f;
24972493
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24982494
Any,
24992495
typeof(_colorvec),
2500-
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2496+
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
25012497
jvp, vjp, jac_prototype, sparsity, Wfact,
25022498
Wfact_t, W_prototype, paramjac,
2503-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2504-
initializeprobpmap, nlprob)
2499+
observed, _colorvec, sys, initdata, nlprob)
25052500
elseif specialize === false
25062501
ODEFunction{iip, FunctionWrapperSpecialize,
25072502
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2510,16 +2505,11 @@ function ODEFunction{iip, specialize}(f;
25102505
typeof(paramjac),
25112506
typeof(observed),
25122507
typeof(_colorvec),
2513-
typeof(sys), typeof(initializeprob),
2514-
typeof(update_initializeprob!),
2515-
typeof(initializeprobmap),
2516-
typeof(initializeprobpmap),
2517-
typeof(nlprob)}(_f, mass_matrix,
2508+
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix,
25182509
analytic, tgrad, jac,
25192510
jvp, vjp, jac_prototype, sparsity, Wfact,
25202511
Wfact_t, W_prototype, paramjac,
2521-
observed, _colorvec, sys, initializeprob, update_initializeprob!,
2522-
initializeprobmap, initializeprobpmap, nlprob)
2512+
observed, _colorvec, sys, initdata, nlprob)
25232513
else
25242514
ODEFunction{iip, specialize,
25252515
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2528,14 +2518,10 @@ function ODEFunction{iip, specialize}(f;
25282518
typeof(paramjac),
25292519
typeof(observed),
25302520
typeof(_colorvec),
2531-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2532-
typeof(initializeprobmap),
2533-
typeof(initializeprobpmap),
2534-
typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac,
2535-
jvp, vjp, jac_prototype, sparsity, Wfact,
2521+
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, analytic, tgrad,
2522+
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
25362523
Wfact_t, W_prototype, paramjac,
2537-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2538-
initializeprobpmap, nlprob)
2524+
observed, _colorvec, sys, initdata, nlprob)
25392525
end
25402526
end
25412527

@@ -2552,28 +2538,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25522538
Any, Any, Any, Any, typeof(f.jac_prototype),
25532539
typeof(f.sparsity), Any, Any, Any,
25542540
Any, typeof(f.colorvec),
2555-
typeof(f.sys), Any, Any, Any, Any, Any}(
2541+
typeof(f.sys), Any, Any}(
25562542
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25572543
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25582544
f.Wfact_t, f.W_prototype, f.paramjac,
2559-
f.observed, f.colorvec, f.sys, f.initializeprob,
2560-
f.update_initializeprob!, f.initializeprobmap,
2561-
f.initializeprobpmap, f.nlprob)
2545+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
25622546
else
25632547
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25642548
typeof(f.analytic), typeof(f.tgrad),
25652549
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
25662550
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25672551
typeof(f.paramjac),
25682552
typeof(f.observed), typeof(f.colorvec),
2569-
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
2570-
typeof(f.initializeprobmap),
2571-
typeof(f.initializeprobpmap),
2572-
typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2553+
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}(
2554+
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25732555
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25742556
f.Wfact_t, f.W_prototype, f.paramjac,
2575-
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
2576-
f.initializeprobmap, f.initializeprobpmap, f.nlprob)
2557+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
25772558
end
25782559
end
25792560

@@ -2704,8 +2685,8 @@ end
27042685

27052686
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
27062687
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
2707-
observed, colorvec, sys, initializeprob, update_initializeprob!,
2708-
initializeprobmap, initializeprobpmap, nlprob)
2688+
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
2689+
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob)
27092690
f1 = ODEFunction(f1)
27102691
f2 = ODEFunction(f2)
27112692

@@ -2714,17 +2695,20 @@ end
27142695
throw(NonconformingFunctionsError(["f2"]))
27152696
end
27162697

2698+
initdata = reconstruct_initialization_data(
2699+
initialization_data, initializeprob, update_initializeprob!,
2700+
initializeprobmap, initializeprobpmap)
2701+
27172702
SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2),
27182703
typeof(mass_matrix),
27192704
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
27202705
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27212706
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2722-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
2723-
typeof(initializeprobpmap), typeof(nlprob)}(
2707+
typeof(sys), typeof(initdata), typeof(nlprob)}(
27242708
f1, f2, mass_matrix,
27252709
cache, analytic, tgrad, jac, jvp, vjp,
27262710
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2727-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
2711+
initdata, nlprob)
27282712
end
27292713
function SplitFunction{iip, specialize}(f1, f2;
27302714
mass_matrix = __has_mass_matrix(f1) ?
@@ -2761,37 +2745,39 @@ function SplitFunction{iip, specialize}(f1, f2;
27612745
f1.update_initializeprob! : nothing,
27622746
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
27632747
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
2764-
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing
2748+
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing,
2749+
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
2750+
nothing
27652751
) where {iip,
27662752
specialize
27672753
}
27682754
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
2769-
@assert typeof(initializeprob) <:
2755+
initdata = reconstruct_initialization_data(
2756+
initialization_data, initializeprob, update_initializeprob!,
2757+
initializeprobmap, initializeprobpmap)
2758+
@assert typeof(initdata.initializeprob) <:
27702759
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
27712760

27722761
if specialize === NoSpecialize
27732762
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
27742763
Any, Any, Any, Any, Any, Any, Any,
2775-
Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2764+
Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27762765
analytic,
27772766
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27782767
sparsity, Wfact, Wfact_t, paramjac,
2779-
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
2780-
initializeprobpmap, initializeprobpmap, nlprob)
2768+
observed, colorvec, sys, initdata, nlprob)
27812769
else
27822770
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27832771
typeof(_func_cache), typeof(analytic),
27842772
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
27852773
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27862774
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27872775
typeof(colorvec),
2788-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2789-
typeof(initializeprobmap),
2790-
typeof(initializeprobpmap), typeof(nlprob)}(f1, f2,
2776+
typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2,
27912777
mass_matrix, _func_cache, analytic, tgrad, jac,
27922778
jvp, vjp, jac_prototype, W_prototype,
27932779
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2794-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
2780+
initdata, nlprob)
27952781
end
27962782
end
27972783

@@ -3420,7 +3406,9 @@ function DAEFunction{iip, specialize}(f;
34203406
update_initializeprob! = __has_update_initializeprob!(f) ?
34213407
f.update_initializeprob! : nothing,
34223408
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
3423-
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
3409+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
3410+
initialization_data = __has_initialization_data(f) ? f.initialization_data :
3411+
nothing) where {
34243412
iip,
34253413
specialize
34263414
}
@@ -3452,33 +3440,32 @@ function DAEFunction{iip, specialize}(f;
34523440

34533441
_f = prepare_function(f)
34543442
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
3443+
initdata = reconstruct_initialization_data(
3444+
initialization_data, initializeprob, update_initializeprob!,
3445+
initializeprobmap, initializeprobpmap)
34553446

3456-
@assert typeof(initializeprob) <:
3447+
@assert typeof(initdata.initializeprob) <:
34573448
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
34583449

34593450
if specialize === NoSpecialize
34603451
DAEFunction{iip, specialize, Any, Any, Any,
34613452
Any, Any, Any, Any, Any,
34623453
Any, Any, Any,
3463-
Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
3454+
Any, typeof(_colorvec), Any, Any}(_f, analytic, tgrad, jac, jvp,
34643455
vjp, jac_prototype, sparsity,
34653456
Wfact, Wfact_t, paramjac, observed,
3466-
_colorvec, sys, initializeprob, update_initializeprob!,
3467-
initializeprobmap, initializeprobpmap)
3457+
_colorvec, sys, initdata)
34683458
else
34693459
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
34703460
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
34713461
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
34723462
typeof(paramjac),
34733463
typeof(observed), typeof(_colorvec),
3474-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
3475-
typeof(initializeprobmap),
3476-
typeof(initializeprobpmap)}(
3464+
typeof(sys), typeof(initdata)}(
34773465
_f, analytic, tgrad, jac, jvp, vjp,
34783466
jac_prototype, sparsity, Wfact, Wfact_t,
34793467
paramjac, observed,
3480-
_colorvec, sys, initializeprob, update_initializeprob!,
3481-
initializeprobmap, initializeprobpmap)
3468+
_colorvec, sys, initdata)
34823469
end
34833470
end
34843471

@@ -4397,6 +4384,14 @@ function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
43974384
return sys
43984385
end
43994386

4387+
function reconstruct_initialization_data(
4388+
initdata, initprob, update_initprob!, initprobmap, initprobpmap)
4389+
if initdata === nothing && initprob !== nothing
4390+
initdata = OverrideInitData(initprob, update_initprob!, initprobmap, initprobpmap)
4391+
end
4392+
return initdata
4393+
end
4394+
44004395
########## Existence Functions
44014396

44024397
# Check that field/property exists (may be nothing)
@@ -4420,11 +4415,20 @@ __has_colorvec(f) = isdefined(f, :colorvec)
44204415
__has_sys(f) = isdefined(f, :sys)
44214416
__has_analytic_full(f) = isdefined(f, :analytic_full)
44224417
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
4423-
__has_initializeprob(f) = isdefined(f, :initializeprob)
4424-
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
4425-
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
4426-
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
44274418
__has_nlprob(f) = isdefined(f, :nlprob)
4419+
function __has_initializeprob(f)
4420+
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
4421+
end
4422+
function __has_update_initializeprob!(f)
4423+
has_initialization_data(f) && isdefined(f.initialization_data, :update_initializeprob!)
4424+
end
4425+
function __has_initializeprobmap(f)
4426+
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobmap)
4427+
end
4428+
function __has_initializeprobpmap(f)
4429+
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobpmap)
4430+
end
4431+
__has_initialization_data(f) = isdefined(f, :initialization_data)
44284432

44294433
# compatibility
44304434
has_invW(f::AbstractSciMLFunction) = false
@@ -4438,16 +4442,20 @@ has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothin
44384442
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
44394443
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
44404444
function has_initializeprob(f::AbstractSciMLFunction)
4441-
__has_initializeprob(f) && f.initializeprob !== nothing
4445+
__has_initializeprob(f) && f.initialization_data.initializeprob !== nothing
44424446
end
44434447
function has_update_initializeprob!(f::AbstractSciMLFunction)
4444-
__has_update_initializeprob!(f) && f.update_initializeprob! !== nothing
4448+
__has_update_initializeprob!(f) &&
4449+
f.initialization_data.update_initializeprob! !== nothing
44454450
end
44464451
function has_initializeprobmap(f::AbstractSciMLFunction)
4447-
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
4452+
__has_initializeprobmap(f) && f.initialization_data.initializeprobmap !== nothing
44484453
end
44494454
function has_initializeprobpmap(f::AbstractSciMLFunction)
4450-
__has_initializeprobpmap(f) && f.initializeprobpmap !== nothing
4455+
__has_initializeprobpmap(f) && f.initialization_data.initializeprobpmap !== nothing
4456+
end
4457+
function has_initialization_data(f::AbstractSciMLFunction)
4458+
__has_initialization_data(f) && f.initialization_data !== nothing
44514459
end
44524460

44534461
function has_syms(f::AbstractSciMLFunction)

0 commit comments

Comments
 (0)