Skip to content

Commit 5ca6771

Browse files
feat: support initializeprobpmap in relevant SciMLFunctions
1 parent 012ff4e commit 5ca6771

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

src/scimlfunctions.jl

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
405+
SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -421,6 +421,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
421421
sys::SYS
422422
initializeprob::IProb
423423
initializeprobmap::IProbMap
424+
initializeprobpmap::IProbPmap
424425
end
425426

426427
@doc doc"""
@@ -518,7 +519,7 @@ information on generating the SplitFunction from this symbolic engine.
518519
struct SplitFunction{
519520
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
520521
TPJ, O,
521-
TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
522+
TCV, SYS, IProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
522523
f1::F1
523524
f2::F2
524525
mass_matrix::TMM
@@ -538,6 +539,7 @@ struct SplitFunction{
538539
sys::SYS
539540
initializeprob::IProb
540541
initializeprobmap::IProbMap
542+
initializeprobpmap::IProbPmap
541543
end
542544

543545
@doc doc"""
@@ -1506,7 +1508,7 @@ automatically symbolically generating the Jacobian and more from the
15061508
numerically-defined functions.
15071509
"""
15081510
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1509-
SYS, IProb, IProbMap} <:
1511+
SYS, IProb, IProbMap, IProbPmap} <:
15101512
AbstractDAEFunction{iip}
15111513
f::F
15121514
analytic::Ta
@@ -1524,6 +1526,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15241526
sys::SYS
15251527
initializeprob::IProb
15261528
initializeprobmap::IProbMap
1529+
initializeprobpmap::IProbPmap
15271530
end
15281531

15291532
"""
@@ -2415,7 +2418,8 @@ function ODEFunction{iip, specialize}(f;
24152418
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
24162419
sys = __has_sys(f) ? f.sys : nothing,
24172420
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
2418-
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
2421+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2422+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
24192423
) where {iip,
24202424
specialize
24212425
}
@@ -2473,10 +2477,11 @@ function ODEFunction{iip, specialize}(f;
24732477
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24742478
Any,
24752479
typeof(_colorvec),
2476-
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2480+
typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24772481
jvp, vjp, jac_prototype, sparsity, Wfact,
24782482
Wfact_t, W_prototype, paramjac,
2479-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2483+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2484+
initializeprobpmap)
24802485
elseif specialize === false
24812486
ODEFunction{iip, FunctionWrapperSpecialize,
24822487
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2486,10 +2491,12 @@ function ODEFunction{iip, specialize}(f;
24862491
typeof(observed),
24872492
typeof(_colorvec),
24882493
typeof(sys), typeof(initializeprob),
2489-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2494+
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
2495+
analytic, tgrad, jac,
24902496
jvp, vjp, jac_prototype, sparsity, Wfact,
24912497
Wfact_t, W_prototype, paramjac,
2492-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2498+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2499+
initializeprobpmap)
24932500
else
24942501
ODEFunction{iip, specialize,
24952502
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2498,11 +2505,12 @@ function ODEFunction{iip, specialize}(f;
24982505
typeof(paramjac),
24992506
typeof(observed),
25002507
typeof(_colorvec),
2501-
typeof(sys), typeof(initializeprob),
2502-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2508+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2509+
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
25032510
jvp, vjp, jac_prototype, sparsity, Wfact,
25042511
Wfact_t, W_prototype, paramjac,
2505-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2512+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2513+
initializeprobpmap)
25062514
end
25072515
end
25082516

@@ -2519,23 +2527,24 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25192527
Any, Any, Any, Any, typeof(f.jac_prototype),
25202528
typeof(f.sparsity), Any, Any, Any,
25212529
Any, typeof(f.colorvec),
2522-
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2530+
typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25232531
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25242532
f.Wfact_t, f.W_prototype, f.paramjac,
2525-
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
2533+
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
2534+
f.initializeprobpmap)
25262535
else
25272536
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25282537
typeof(f.analytic), typeof(f.tgrad),
25292538
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
25302539
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25312540
typeof(f.paramjac),
25322541
typeof(f.observed), typeof(f.colorvec),
2533-
typeof(f.sys), typeof(f.initializeprob),
2534-
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2542+
typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap),
2543+
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25352544
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25362545
f.Wfact_t, f.W_prototype, f.paramjac,
25372546
f.observed, f.colorvec, f.sys, f.initializeprob,
2538-
f.initializeprobmap)
2547+
f.initializeprobmap, f.initializeprobpmap)
25392548
end
25402549
end
25412550

@@ -2637,7 +2646,7 @@ end
26372646

26382647
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
26392648
vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
2640-
observed, colorvec, sys, initializeprob, initializeprobmap)
2649+
observed, colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
26412650
f1 = ODEFunction(f1)
26422651
f2 = ODEFunction(f2)
26432652

@@ -2651,11 +2660,12 @@ end
26512660
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
26522661
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
26532662
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2654-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
2663+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2664+
typeof(initializeprobpmap)}(
26552665
f1, f2, mass_matrix,
26562666
cache, analytic, tgrad, jac, jvp, vjp,
26572667
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2658-
initializeprob, initializeprobmap)
2668+
initializeprob, initializeprobmap, initializeprobpmap)
26592669
end
26602670
function SplitFunction{iip, specialize}(f1, f2;
26612671
mass_matrix = __has_mass_matrix(f1) ?
@@ -2685,7 +2695,8 @@ function SplitFunction{iip, specialize}(f1, f2;
26852695
nothing,
26862696
sys = __has_sys(f1) ? f1.sys : nothing,
26872697
initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing,
2688-
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing
2698+
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
2699+
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
26892700
) where {iip,
26902701
specialize
26912702
}
@@ -2696,23 +2707,25 @@ function SplitFunction{iip, specialize}(f1, f2;
26962707
if specialize === NoSpecialize
26972708
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
26982709
Any, Any, Any, Any, Any,
2699-
Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2710+
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27002711
analytic,
27012712
tgrad, jac, jvp, vjp, jac_prototype,
27022713
sparsity, Wfact, Wfact_t, paramjac,
2703-
observed, colorvec, sys, initializeprob, initializeprobmap)
2714+
observed, colorvec, sys, initializeprob, initializeprobmap,
2715+
initializeprobpmap, initializeprobpmap)
27042716
else
27052717
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27062718
typeof(_func_cache), typeof(analytic),
27072719
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
27082720
typeof(jac_prototype), typeof(sparsity),
27092721
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27102722
typeof(colorvec),
2711-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2,
2723+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
2724+
typeof(initializeprobpmap)}(f1, f2,
27122725
mass_matrix, _func_cache, analytic, tgrad, jac,
27132726
jvp, vjp, jac_prototype,
27142727
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2715-
initializeprob, initializeprobmap)
2728+
initializeprob, initializeprobmap, initializeprobpmap)
27162729
end
27172730
end
27182731

@@ -3338,7 +3351,8 @@ function DAEFunction{iip, specialize}(f;
33383351
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
33393352
sys = __has_sys(f) ? f.sys : nothing,
33403353
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
3341-
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {
3354+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
3355+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
33423356
iip,
33433357
specialize
33443358
}
@@ -3378,21 +3392,22 @@ function DAEFunction{iip, specialize}(f;
33783392
DAEFunction{iip, specialize, Any, Any, Any,
33793393
Any, Any, Any, Any, Any,
33803394
Any, Any, Any,
3381-
Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
3395+
Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
33823396
vjp, jac_prototype, sparsity,
33833397
Wfact, Wfact_t, paramjac, observed,
3384-
_colorvec, sys, initializeprob, initializeprobmap)
3398+
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
33853399
else
33863400
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
33873401
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
33883402
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
33893403
typeof(paramjac),
33903404
typeof(observed), typeof(_colorvec),
3391-
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
3405+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
3406+
typeof(initializeprobpmap)}(
33923407
_f, analytic, tgrad, jac, jvp, vjp,
33933408
jac_prototype, sparsity, Wfact, Wfact_t,
33943409
paramjac, observed,
3395-
_colorvec, sys, initializeprob, initializeprobmap)
3410+
_colorvec, sys, initializeprob, initializeprobmap, initializeprobpmap)
33963411
end
33973412
end
33983413

@@ -4336,6 +4351,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full)
43364351
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
43374352
__has_initializeprob(f) = isdefined(f, :initializeprob)
43384353
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
4354+
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
43394355

43404356
# compatibility
43414357
has_invW(f::AbstractSciMLFunction) = false
@@ -4354,6 +4370,9 @@ end
43544370
function has_initializeprobmap(f::AbstractSciMLFunction)
43554371
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
43564372
end
4373+
function has_initializeprobpmap(f::AbstractSciMLFunction)
4374+
__has_initializeprobpmap(f) && f.initializeprobpmap !== nothing
4375+
end
43574376

43584377
function has_syms(f::AbstractSciMLFunction)
43594378
if __has_syms(f)

0 commit comments

Comments
 (0)