Skip to content

Commit 8723590

Browse files
feat: support initializeprobpmap in relevant SciMLFunctions
1 parent 70ae7fb commit 8723590

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

src/scimlfunctions.jl

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
405+
SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -421,6 +421,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
421421
sys::SYS
422422
initializeprob::IProb
423423
initializeprobmap::IProbMap
424+
initializeprobpmap::IProbPmap
424425
end
425426

426427
@doc doc"""
@@ -518,7 +519,7 @@ information on generating the SplitFunction from this symbolic engine.
518519
struct SplitFunction{
519520
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
520521
TPJ, O,
521-
TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
522+
TCV, SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
522523
f1::F1
523524
f2::F2
524525
mass_matrix::TMM
@@ -538,6 +539,7 @@ struct SplitFunction{
538539
sys::SYS
539540
initializeprob::IProb
540541
initializeprobmap::IProbMap
542+
initializeprobpmap::IProbPmap
541543
end
542544

543545
@doc doc"""
@@ -1506,7 +1508,7 @@ automatically symbolically generating the Jacobian and more from the
15061508
numerically-defined functions.
15071509
"""
15081510
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1509-
SYS, IProb, IProbMap} <:
1511+
SYS, IProb, IProbMap, IProbPmap} <:
15101512
AbstractDAEFunction{iip}
15111513
f::F
15121514
analytic::Ta
@@ -1524,6 +1526,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15241526
sys::SYS
15251527
initializeprob::IProb
15261528
initializeprobmap::IProbMap
1529+
initializeprobpmap::IProbPmap
15271530
end
15281531

15291532
"""
@@ -2410,7 +2413,8 @@ function ODEFunction{iip, specialize}(f;
24102413
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
24112414
sys = __has_sys(f) ? f.sys : nothing,
24122415
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
2413-
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
2416+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2417+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
24142418
) where {iip,
24152419
specialize
24162420
}
@@ -2468,10 +2472,11 @@ function ODEFunction{iip, specialize}(f;
24682472
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24692473
Any,
24702474
typeof(_colorvec),
2471-
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2475+
typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24722476
jvp, vjp, jac_prototype, sparsity, Wfact,
24732477
Wfact_t, W_prototype, paramjac,
2474-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2478+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2479+
initializeprobpmap)
24752480
elseif specialize === false
24762481
ODEFunction{iip, FunctionWrapperSpecialize,
24772482
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2481,10 +2486,12 @@ function ODEFunction{iip, specialize}(f;
24812486
typeof(observed),
24822487
typeof(_colorvec),
24832488
typeof(sys), typeof(initializeprob),
2484-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2489+
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
2490+
analytic, tgrad, jac,
24852491
jvp, vjp, jac_prototype, sparsity, Wfact,
24862492
Wfact_t, W_prototype, paramjac,
2487-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2493+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2494+
initializeprobpmap)
24882495
else
24892496
ODEFunction{iip, specialize,
24902497
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2493,11 +2500,12 @@ function ODEFunction{iip, specialize}(f;
24932500
typeof(paramjac),
24942501
typeof(observed),
24952502
typeof(_colorvec),
2496-
typeof(sys), typeof(initializeprob),
2497-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2503+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2504+
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
24982505
jvp, vjp, jac_prototype, sparsity, Wfact,
24992506
Wfact_t, W_prototype, paramjac,
2500-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2507+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2508+
initializeprobpmap)
25012509
end
25022510
end
25032511

@@ -2514,23 +2522,24 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25142522
Any, Any, Any, Any, typeof(f.jac_prototype),
25152523
typeof(f.sparsity), Any, Any, Any,
25162524
Any, typeof(f.colorvec),
2517-
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2525+
typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25182526
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25192527
f.Wfact_t, f.W_prototype, f.paramjac,
2520-
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
2528+
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
2529+
f.initializeprobpmap)
25212530
else
25222531
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25232532
typeof(f.analytic), typeof(f.tgrad),
25242533
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
25252534
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25262535
typeof(f.paramjac),
25272536
typeof(f.observed), typeof(f.colorvec),
2528-
typeof(f.sys), typeof(f.initializeprob),
2529-
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2537+
typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap),
2538+
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25302539
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25312540
f.Wfact_t, f.W_prototype, f.paramjac,
25322541
f.observed, f.colorvec, f.sys, f.initializeprob,
2533-
f.initializeprobmap)
2542+
f.initializeprobmap, f.initializeprobpmap)
25342543
end
25352544
end
25362545

@@ -2632,7 +2641,7 @@ end
26322641

26332642
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
26342643
vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
2635-
observed, colorvec, sys, initializeprob, initializeprobmap)
2644+
observed, colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
26362645
f1 = ODEFunction(f1)
26372646
f2 = ODEFunction(f2)
26382647

@@ -2646,11 +2655,12 @@ end
26462655
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
26472656
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
26482657
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2649-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
2658+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2659+
typeof(initializeprobpmap)}(
26502660
f1, f2, mass_matrix,
26512661
cache, analytic, tgrad, jac, jvp, vjp,
26522662
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2653-
initializeprob, initializeprobmap)
2663+
initializeprob, initializeprobmap, initializeprobpmap)
26542664
end
26552665
function SplitFunction{iip, specialize}(f1, f2;
26562666
mass_matrix = __has_mass_matrix(f1) ?
@@ -2680,7 +2690,8 @@ function SplitFunction{iip, specialize}(f1, f2;
26802690
nothing,
26812691
sys = __has_sys(f1) ? f1.sys : nothing,
26822692
initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing,
2683-
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing
2693+
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
2694+
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
26842695
) where {iip,
26852696
specialize
26862697
}
@@ -2691,23 +2702,25 @@ function SplitFunction{iip, specialize}(f1, f2;
26912702
if specialize === NoSpecialize
26922703
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
26932704
Any, Any, Any, Any, Any,
2694-
Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2705+
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
26952706
analytic,
26962707
tgrad, jac, jvp, vjp, jac_prototype,
26972708
sparsity, Wfact, Wfact_t, paramjac,
2698-
observed, colorvec, sys, initializeprob, initializeprobmap)
2709+
observed, colorvec, sys, initializeprob, initializeprobmap,
2710+
initializeprobpmap, initializeprobpmap)
26992711
else
27002712
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27012713
typeof(_func_cache), typeof(analytic),
27022714
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
27032715
typeof(jac_prototype), typeof(sparsity),
27042716
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27052717
typeof(colorvec),
2706-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2,
2718+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2719+
typeof(initializeprobpmap)}(f1, f2,
27072720
mass_matrix, _func_cache, analytic, tgrad, jac,
27082721
jvp, vjp, jac_prototype,
27092722
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2710-
initializeprob, initializeprobmap)
2723+
initializeprob, initializeprobmap, initializeprobpmap)
27112724
end
27122725
end
27132726

@@ -3333,7 +3346,8 @@ function DAEFunction{iip, specialize}(f;
33333346
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
33343347
sys = __has_sys(f) ? f.sys : nothing,
33353348
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
3336-
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {
3349+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
3350+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
33373351
iip,
33383352
specialize
33393353
}
@@ -3373,21 +3387,22 @@ function DAEFunction{iip, specialize}(f;
33733387
DAEFunction{iip, specialize, Any, Any, Any,
33743388
Any, Any, Any, Any, Any,
33753389
Any, Any, Any,
3376-
Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
3390+
Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
33773391
vjp, jac_prototype, sparsity,
33783392
Wfact, Wfact_t, paramjac, observed,
3379-
_colorvec, sys, initializeprob, initializeprobmap)
3393+
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
33803394
else
33813395
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
33823396
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
33833397
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
33843398
typeof(paramjac),
33853399
typeof(observed), typeof(_colorvec),
3386-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
3400+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
3401+
typeof(initializeprobpmap)}(
33873402
_f, analytic, tgrad, jac, jvp, vjp,
33883403
jac_prototype, sparsity, Wfact, Wfact_t,
33893404
paramjac, observed,
3390-
_colorvec, sys, initializeprob, initializeprobmap)
3405+
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
33913406
end
33923407
end
33933408

@@ -4331,6 +4346,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full)
43314346
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
43324347
__has_initializeprob(f) = isdefined(f, :initializeprob)
43334348
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
4349+
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
43344350

43354351
# compatibility
43364352
has_invW(f::AbstractSciMLFunction) = false
@@ -4349,6 +4365,9 @@ end
43494365
function has_initializeprobmap(f::AbstractSciMLFunction)
43504366
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
43514367
end
4368+
function has_initializeprobpmap(f::AbstractSciMLFunction)
4369+
__has_initializeprobpmap(f) && f.initializeprobpmap !== nothing
4370+
end
43524371

43534372
function has_syms(f::AbstractSciMLFunction)
43544373
if __has_syms(f)

0 commit comments

Comments
 (0)