Skip to content

Commit 3b79cc8

Browse files
refactor: use initialization_data instead of initializeprob, etc.
1 parent 7425a86 commit 3b79cc8

File tree

1 file changed

+83
-70
lines changed

1 file changed

+83
-70
lines changed

src/scimlfunctions.jl

Lines changed: 83 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
405+
SYS, ID} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -419,10 +419,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
419419
observed::O
420420
colorvec::TCV
421421
sys::SYS
422-
initializeprob::IProb
423-
update_initializeprob!::UIProb
424-
initializeprobmap::IProbMap
425-
initializeprobpmap::IProbPmap
422+
initialization_data::ID
426423
end
427424

428425
@doc doc"""
@@ -526,7 +523,7 @@ information on generating the SplitFunction from this symbolic engine.
526523
struct SplitFunction{
527524
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
528525
TPJ, O,
529-
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
526+
TCV, SYS, ID} <: AbstractODEFunction{iip}
530527
f1::F1
531528
f2::F2
532529
mass_matrix::TMM
@@ -545,10 +542,7 @@ struct SplitFunction{
545542
observed::O
546543
colorvec::TCV
547544
sys::SYS
548-
initializeprob::IProb
549-
update_initializeprob!::UIProb
550-
initializeprobmap::IProbMap
551-
initializeprobpmap::IProbPmap
545+
initialization_data::ID
552546
end
553547

554548
@doc doc"""
@@ -1523,7 +1517,7 @@ automatically symbolically generating the Jacobian and more from the
15231517
numerically-defined functions.
15241518
"""
15251519
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1526-
SYS, IProb, UIProb, IProbMap, IProbPmap} <:
1520+
SYS, ID} <:
15271521
AbstractDAEFunction{iip}
15281522
f::F
15291523
analytic::Ta
@@ -1539,10 +1533,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15391533
observed::O
15401534
colorvec::TCV
15411535
sys::SYS
1542-
initializeprob::IProb
1543-
update_initializeprob!::UIProb
1544-
initializeprobmap::IProbMap
1545-
initializeprobpmap::IProbPmap
1536+
initialization_data::ID
15461537
end
15471538

15481539
"""
@@ -2432,7 +2423,9 @@ function ODEFunction{iip, specialize}(f;
24322423
update_initializeprob! = __has_update_initializeprob!(f) ?
24332424
f.update_initializeprob! : nothing,
24342425
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2435-
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
2426+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
2427+
initialization_data = __has_initialization_data(f) ? f.initialization_data :
2428+
nothing
24362429
) where {iip,
24372430
specialize
24382431
}
@@ -2479,8 +2472,11 @@ function ODEFunction{iip, specialize}(f;
24792472
_f = prepare_function(f)
24802473

24812474
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
2475+
initdata = reconstruct_initialization_data(
2476+
initialization_data, initializeprob, update_initializeprob!,
2477+
initializeprobmap, initializeprobpmap)
24822478

2483-
@assert typeof(initializeprob) <:
2479+
@assert typeof(initdata.initializeprob) <:
24842480
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
24852481

24862482
if specialize === NoSpecialize
@@ -2490,11 +2486,10 @@ function ODEFunction{iip, specialize}(f;
24902486
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24912487
Any,
24922488
typeof(_colorvec),
2493-
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2489+
typeof(sys), Any}(_f, mass_matrix, analytic, tgrad, jac,
24942490
jvp, vjp, jac_prototype, sparsity, Wfact,
24952491
Wfact_t, W_prototype, paramjac,
2496-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2497-
initializeprobpmap)
2492+
observed, _colorvec, sys, initdata)
24982493
elseif specialize === false
24992494
ODEFunction{iip, FunctionWrapperSpecialize,
25002495
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2503,13 +2498,11 @@ function ODEFunction{iip, specialize}(f;
25032498
typeof(paramjac),
25042499
typeof(observed),
25052500
typeof(_colorvec),
2506-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2507-
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
2501+
typeof(sys), typeof(initdata)}(_f, mass_matrix,
25082502
analytic, tgrad, jac,
25092503
jvp, vjp, jac_prototype, sparsity, Wfact,
25102504
Wfact_t, W_prototype, paramjac,
2511-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2512-
initializeprobpmap)
2505+
observed, _colorvec, sys, initdata)
25132506
else
25142507
ODEFunction{iip, specialize,
25152508
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2518,13 +2511,10 @@ function ODEFunction{iip, specialize}(f;
25182511
typeof(paramjac),
25192512
typeof(observed),
25202513
typeof(_colorvec),
2521-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2522-
typeof(initializeprobmap),
2523-
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
2514+
typeof(sys), typeof(initdata)}(_f, mass_matrix, analytic, tgrad, jac,
25242515
jvp, vjp, jac_prototype, sparsity, Wfact,
25252516
Wfact_t, W_prototype, paramjac,
2526-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2527-
initializeprobpmap)
2517+
observed, _colorvec, sys, initdata)
25282518
end
25292519
end
25302520

@@ -2541,27 +2531,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25412531
Any, Any, Any, Any, typeof(f.jac_prototype),
25422532
typeof(f.sparsity), Any, Any, Any,
25432533
Any, typeof(f.colorvec),
2544-
typeof(f.sys), Any, Any, Any, Any}(
2534+
typeof(f.sys), Any}(
25452535
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25462536
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25472537
f.Wfact_t, f.W_prototype, f.paramjac,
2548-
f.observed, f.colorvec, f.sys, f.initializeprob,
2549-
f.update_initializeprob!, f.initializeprobmap,
2550-
f.initializeprobpmap)
2538+
f.observed, f.colorvec, f.sys, f.initialization_data)
25512539
else
25522540
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25532541
typeof(f.analytic), typeof(f.tgrad),
25542542
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
25552543
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25562544
typeof(f.paramjac),
25572545
typeof(f.observed), typeof(f.colorvec),
2558-
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
2559-
typeof(f.initializeprobmap),
2560-
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2546+
typeof(f.sys), typeof(f.initialization_data)}(
2547+
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25612548
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25622549
f.Wfact_t, f.W_prototype, f.paramjac,
2563-
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
2564-
f.initializeprobmap, f.initializeprobpmap)
2550+
f.observed, f.colorvec, f.sys, f.initialization_data)
25652551
end
25662552
end
25672553

@@ -2692,8 +2678,8 @@ end
26922678

26932679
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
26942680
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
2695-
observed, colorvec, sys, initializeprob, update_initializeprob!,
2696-
initializeprobmap, initializeprobpmap)
2681+
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
2682+
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing)
26972683
f1 = ODEFunction(f1)
26982684
f2 = ODEFunction(f2)
26992685

@@ -2702,17 +2688,20 @@ end
27022688
throw(NonconformingFunctionsError(["f2"]))
27032689
end
27042690

2691+
initdata = reconstruct_initialization_data(
2692+
initialization_data, initializeprob, update_initializeprob!,
2693+
initializeprobmap, initializeprobpmap)
2694+
27052695
SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2),
27062696
typeof(mass_matrix),
27072697
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
27082698
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27092699
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2710-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
2711-
typeof(initializeprobpmap)}(
2700+
typeof(sys), typeof(initdata)}(
27122701
f1, f2, mass_matrix,
27132702
cache, analytic, tgrad, jac, jvp, vjp,
27142703
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2715-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
2704+
initdata)
27162705
end
27172706
function SplitFunction{iip, specialize}(f1, f2;
27182707
mass_matrix = __has_mass_matrix(f1) ?
@@ -2748,37 +2737,39 @@ function SplitFunction{iip, specialize}(f1, f2;
27482737
update_initializeprob! = __has_update_initializeprob!(f1) ?
27492738
f1.update_initializeprob! : nothing,
27502739
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
2751-
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
2740+
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
2741+
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
2742+
nothing
27522743
) where {iip,
27532744
specialize
27542745
}
27552746
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
2756-
@assert typeof(initializeprob) <:
2747+
initdata = reconstruct_initialization_data(
2748+
initialization_data, initializeprob, update_initializeprob!,
2749+
initializeprobmap, initializeprobpmap)
2750+
@assert typeof(initdata.initializeprob) <:
27572751
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
27582752

27592753
if specialize === NoSpecialize
27602754
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
27612755
Any, Any, Any, Any, Any, Any, Any,
2762-
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2756+
Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27632757
analytic,
27642758
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27652759
sparsity, Wfact, Wfact_t, paramjac,
2766-
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
2767-
initializeprobpmap, initializeprobpmap)
2760+
observed, colorvec, sys, initdata)
27682761
else
27692762
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27702763
typeof(_func_cache), typeof(analytic),
27712764
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
27722765
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27732766
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27742767
typeof(colorvec),
2775-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2776-
typeof(initializeprobmap),
2777-
typeof(initializeprobpmap)}(f1, f2,
2768+
typeof(sys), typeof(initdata)}(f1, f2,
27782769
mass_matrix, _func_cache, analytic, tgrad, jac,
27792770
jvp, vjp, jac_prototype, W_prototype,
27802771
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2781-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
2772+
initdata)
27822773
end
27832774
end
27842775

@@ -3407,7 +3398,9 @@ function DAEFunction{iip, specialize}(f;
34073398
update_initializeprob! = __has_update_initializeprob!(f) ?
34083399
f.update_initializeprob! : nothing,
34093400
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
3410-
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
3401+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
3402+
initialization_data = __has_initialization_data(f) ? f.initialization_data :
3403+
nothing) where {
34113404
iip,
34123405
specialize
34133406
}
@@ -3439,33 +3432,32 @@ function DAEFunction{iip, specialize}(f;
34393432

34403433
_f = prepare_function(f)
34413434
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
3435+
initdata = reconstruct_initialization_data(
3436+
initialization_data, initializeprob, update_initializeprob!,
3437+
initializeprobmap, initializeprobpmap)
34423438

3443-
@assert typeof(initializeprob) <:
3439+
@assert typeof(initdata.initializeprob) <:
34443440
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
34453441

34463442
if specialize === NoSpecialize
34473443
DAEFunction{iip, specialize, Any, Any, Any,
34483444
Any, Any, Any, Any, Any,
34493445
Any, Any, Any,
3450-
Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
3446+
Any, typeof(_colorvec), Any, Any}(_f, analytic, tgrad, jac, jvp,
34513447
vjp, jac_prototype, sparsity,
34523448
Wfact, Wfact_t, paramjac, observed,
3453-
_colorvec, sys, initializeprob, update_initializeprob!,
3454-
initializeprobmap, initializeprobpmap)
3449+
_colorvec, sys, initdata)
34553450
else
34563451
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
34573452
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
34583453
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
34593454
typeof(paramjac),
34603455
typeof(observed), typeof(_colorvec),
3461-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
3462-
typeof(initializeprobmap),
3463-
typeof(initializeprobpmap)}(
3456+
typeof(sys), typeof(initdata)}(
34643457
_f, analytic, tgrad, jac, jvp, vjp,
34653458
jac_prototype, sparsity, Wfact, Wfact_t,
34663459
paramjac, observed,
3467-
_colorvec, sys, initializeprob, update_initializeprob!,
3468-
initializeprobmap, initializeprobpmap)
3460+
_colorvec, sys, initdata)
34693461
end
34703462
end
34713463

@@ -4384,6 +4376,14 @@ function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
43844376
return sys
43854377
end
43864378

4379+
function reconstruct_initialization_data(
4380+
initdata, initprob, update_initprob!, initprobmap, initprobpmap)
4381+
if initdata === nothing && initprob !== nothing
4382+
initdata = InitializationData(initprob, update_initprob!, initprobmap, initprobpmap)
4383+
end
4384+
return initprob
4385+
end
4386+
43874387
########## Existence Functions
43884388

43894389
# Check that field/property exists (may be nothing)
@@ -4407,10 +4407,19 @@ __has_colorvec(f) = isdefined(f, :colorvec)
44074407
__has_sys(f) = isdefined(f, :sys)
44084408
__has_analytic_full(f) = isdefined(f, :analytic_full)
44094409
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
4410-
__has_initializeprob(f) = isdefined(f, :initializeprob)
4411-
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
4412-
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
4413-
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
4410+
function __has_initializeprob(f)
4411+
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
4412+
end
4413+
function __has_update_initializeprob!(f)
4414+
has_initialization_data(f) && isdefined(f.initialization_data, :update_initializeprob!)
4415+
end
4416+
function __has_initializeprobmap(f)
4417+
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobmap)
4418+
end
4419+
function __has_initializeprobpmap(f)
4420+
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobpmap)
4421+
end
4422+
__has_initialization_data(f) = isdefined(f, :initialization_data)
44144423

44154424
# compatibility
44164425
has_invW(f::AbstractSciMLFunction) = false
@@ -4424,16 +4433,20 @@ has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothin
44244433
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
44254434
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
44264435
function has_initializeprob(f::AbstractSciMLFunction)
4427-
__has_initializeprob(f) && f.initializeprob !== nothing
4436+
__has_initializeprob(f) && f.initialization_data.initializeprob !== nothing
44284437
end
44294438
function has_update_initializeprob!(f::AbstractSciMLFunction)
4430-
__has_update_initializeprob!(f) && f.update_initializeprob! !== nothing
4439+
__has_update_initializeprob!(f) &&
4440+
f.initialization_data.update_initializeprob! !== nothing
44314441
end
44324442
function has_initializeprobmap(f::AbstractSciMLFunction)
4433-
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
4443+
__has_initializeprobmap(f) && f.initialization_data.initializeprobmap !== nothing
44344444
end
44354445
function has_initializeprobpmap(f::AbstractSciMLFunction)
4436-
__has_initializeprobpmap(f) && f.initializeprobpmap !== nothing
4446+
__has_initializeprobpmap(f) && f.initialization_data.initializeprobpmap !== nothing
4447+
end
4448+
function has_initialization_data(f::AbstractSciMLFunction)
4449+
__has_initialization_data(f) && f.initialization_data !== nothing
44374450
end
44384451

44394452
function has_syms(f::AbstractSciMLFunction)

0 commit comments

Comments
 (0)