@@ -518,7 +518,7 @@ numerically-defined functions. See `ModelingToolkit.SplitODEProblem` for
518
518
information on generating the SplitFunction from this symbolic engine.
519
519
"""
520
520
struct SplitFunction{
521
- iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
521
+ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
522
522
TPJ, O,
523
523
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
524
524
f1:: F1
@@ -531,6 +531,7 @@ struct SplitFunction{
531
531
jvp:: JVP
532
532
vjp:: VJP
533
533
jac_prototype:: JP
534
+ W_prototype:: WP
534
535
sparsity:: SP
535
536
Wfact:: TW
536
537
Wfact_t:: TWt
@@ -1813,9 +1814,9 @@ OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
1813
1814
1814
1815
## Positional Arguments
1815
1816
1816
- - `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective,
1817
+ - `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective,
1817
1818
even if no such parameters are used in the objective it should be an argument in the function. For minibatching `p` can be used to pass in
1818
- a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it.
1819
+ a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it.
1819
1820
This should return a scalar, the loss value, as the return output.
1820
1821
- `adtype`: see the Defining Optimization Functions via AD section below.
1821
1822
@@ -2649,7 +2650,7 @@ function NonlinearFunction{iip}(f::ODEFunction) where {iip}
2649
2650
end
2650
2651
2651
2652
@add_kwonly function SplitFunction (f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
2652
- vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
2653
+ vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
2653
2654
observed, colorvec, sys, initializeprob, update_initializeprob!,
2654
2655
initializeprobmap, initializeprobpmap)
2655
2656
f1 = ODEFunction (f1)
@@ -2663,13 +2664,13 @@ end
2663
2664
SplitFunction{isinplace (f2), FullSpecialize, typeof (f1), typeof (f2),
2664
2665
typeof (mass_matrix),
2665
2666
typeof (cache), typeof (analytic), typeof (tgrad), typeof (jac), typeof (jvp),
2666
- typeof (vjp), typeof (jac_prototype), typeof (sparsity),
2667
+ typeof (vjp), typeof (jac_prototype), typeof (W_prototype), typeof ( sparsity),
2667
2668
typeof (Wfact), typeof (Wfact_t), typeof (paramjac), typeof (observed), typeof (colorvec),
2668
2669
typeof (sys), typeof (initializeprob), typeof (update_initializeprob!), typeof (initializeprobmap),
2669
2670
typeof (initializeprobpmap)}(
2670
2671
f1, f2, mass_matrix,
2671
2672
cache, analytic, tgrad, jac, jvp, vjp,
2672
- jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2673
+ jac_prototype, W__prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2673
2674
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
2674
2675
end
2675
2676
function SplitFunction {iip, specialize} (f1, f2;
@@ -2685,6 +2686,9 @@ function SplitFunction{iip, specialize}(f1, f2;
2685
2686
jac_prototype = __has_jac_prototype (f1) ?
2686
2687
f1. jac_prototype :
2687
2688
nothing ,
2689
+ W_prototype = __has_W_prototype (f1) ?
2690
+ f1. W_prototype :
2691
+ nothing ,
2688
2692
sparsity = __has_sparsity (f1) ? f1. sparsity :
2689
2693
jac_prototype,
2690
2694
Wfact = __has_Wfact (f1) ? f1. Wfact : nothing ,
@@ -2713,25 +2717,25 @@ function SplitFunction{iip, specialize}(f1, f2;
2713
2717
2714
2718
if specialize === NoSpecialize
2715
2719
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
2716
- Any, Any, Any, Any, Any, Any,
2720
+ Any, Any, Any, Any, Any, Any, Any,
2717
2721
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2718
2722
analytic,
2719
- tgrad, jac, jvp, vjp, jac_prototype,
2723
+ tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
2720
2724
sparsity, Wfact, Wfact_t, paramjac,
2721
2725
observed, colorvec, sys, initializeprob. update_initializeprob!, initializeprobmap,
2722
2726
initializeprobpmap, initializeprobpmap)
2723
2727
else
2724
2728
SplitFunction{iip, specialize, typeof (f1), typeof (f2), typeof (mass_matrix),
2725
2729
typeof (_func_cache), typeof (analytic),
2726
2730
typeof (tgrad), typeof (jac), typeof (jvp), typeof (vjp),
2727
- typeof (jac_prototype), typeof (sparsity),
2731
+ typeof (jac_prototype), typeof (W_prototype), typeof ( sparsity),
2728
2732
typeof (Wfact), typeof (Wfact_t), typeof (paramjac), typeof (observed),
2729
2733
typeof (colorvec),
2730
2734
typeof (sys), typeof (initializeprob), typeof (update_initializeprob!),
2731
2735
typeof (initializeprobmap),
2732
2736
typeof (initializeprobpmap)}(f1, f2,
2733
2737
mass_matrix, _func_cache, analytic, tgrad, jac,
2734
- jvp, vjp, jac_prototype,
2738
+ jvp, vjp, jac_prototype, W_prototype,
2735
2739
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2736
2740
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
2737
2741
end
0 commit comments