@@ -518,7 +518,7 @@ numerically-defined functions. See `ModelingToolkit.SplitODEProblem` for
518518information on generating the SplitFunction from this symbolic engine.
519519"""
520520struct SplitFunction{
521- iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
521+ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
522522 TPJ, O,
523523 TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
524524 f1:: F1
@@ -531,6 +531,7 @@ struct SplitFunction{
531531 jvp:: JVP
532532 vjp:: VJP
533533 jac_prototype:: JP
534+ W_prototype:: WP
534535 sparsity:: SP
535536 Wfact:: TW
536537 Wfact_t:: TWt
@@ -1813,9 +1814,9 @@ OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
18131814
18141815## Positional Arguments
18151816
1816- - `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective,
1817+ - `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective,
18171818even if no such parameters are used in the objective it should be an argument in the function. For minibatching `p` can be used to pass in
1818- a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it.
1819+ a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it.
18191820This should return a scalar, the loss value, as the return output.
18201821- `adtype`: see the Defining Optimization Functions via AD section below.
18211822
@@ -2649,7 +2650,7 @@ function NonlinearFunction{iip}(f::ODEFunction) where {iip}
26492650end
26502651
26512652@add_kwonly function SplitFunction (f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
2652- vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
2653+ vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
26532654 observed, colorvec, sys, initializeprob, update_initializeprob!,
26542655 initializeprobmap, initializeprobpmap)
26552656 f1 = ODEFunction (f1)
@@ -2663,13 +2664,13 @@ end
26632664 SplitFunction{isinplace (f2), FullSpecialize, typeof (f1), typeof (f2),
26642665 typeof (mass_matrix),
26652666 typeof (cache), typeof (analytic), typeof (tgrad), typeof (jac), typeof (jvp),
2666- typeof (vjp), typeof (jac_prototype), typeof (sparsity),
2667+ typeof (vjp), typeof (jac_prototype), typeof (W_prototype), typeof ( sparsity),
26672668 typeof (Wfact), typeof (Wfact_t), typeof (paramjac), typeof (observed), typeof (colorvec),
26682669 typeof (sys), typeof (initializeprob), typeof (update_initializeprob!), typeof (initializeprobmap),
26692670 typeof (initializeprobpmap)}(
26702671 f1, f2, mass_matrix,
26712672 cache, analytic, tgrad, jac, jvp, vjp,
2672- jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2673+ jac_prototype, W__prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
26732674 initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
26742675end
26752676function SplitFunction {iip, specialize} (f1, f2;
@@ -2685,6 +2686,9 @@ function SplitFunction{iip, specialize}(f1, f2;
26852686 jac_prototype = __has_jac_prototype (f1) ?
26862687 f1. jac_prototype :
26872688 nothing ,
2689+ W_prototype = __has_W_prototype (f1) ?
2690+ f1. W_prototype :
2691+ nothing ,
26882692 sparsity = __has_sparsity (f1) ? f1. sparsity :
26892693 jac_prototype,
26902694 Wfact = __has_Wfact (f1) ? f1. Wfact : nothing ,
@@ -2713,25 +2717,25 @@ function SplitFunction{iip, specialize}(f1, f2;
27132717
27142718 if specialize === NoSpecialize
27152719 SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
2716- Any, Any, Any, Any, Any, Any,
2720+ Any, Any, Any, Any, Any, Any, Any,
27172721 Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27182722 analytic,
2719- tgrad, jac, jvp, vjp, jac_prototype,
2723+ tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27202724 sparsity, Wfact, Wfact_t, paramjac,
27212725 observed, colorvec, sys, initializeprob. update_initializeprob!, initializeprobmap,
27222726 initializeprobpmap, initializeprobpmap)
27232727 else
27242728 SplitFunction{iip, specialize, typeof (f1), typeof (f2), typeof (mass_matrix),
27252729 typeof (_func_cache), typeof (analytic),
27262730 typeof (tgrad), typeof (jac), typeof (jvp), typeof (vjp),
2727- typeof (jac_prototype), typeof (sparsity),
2731+ typeof (jac_prototype), typeof (W_prototype), typeof ( sparsity),
27282732 typeof (Wfact), typeof (Wfact_t), typeof (paramjac), typeof (observed),
27292733 typeof (colorvec),
27302734 typeof (sys), typeof (initializeprob), typeof (update_initializeprob!),
27312735 typeof (initializeprobmap),
27322736 typeof (initializeprobpmap)}(f1, f2,
27332737 mass_matrix, _func_cache, analytic, tgrad, jac,
2734- jvp, vjp, jac_prototype,
2738+ jvp, vjp, jac_prototype, W_prototype,
27352739 sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
27362740 initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
27372741 end
0 commit comments