Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ numerically-defined functions. See `ModelingToolkit.SplitODEProblem` for
information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
f1::F1
Expand All @@ -531,6 +531,7 @@ struct SplitFunction{
jvp::JVP
vjp::VJP
jac_prototype::JP
W_prototype::WP
sparsity::SP
Wfact::TW
Wfact_t::TWt
Expand Down Expand Up @@ -1813,9 +1814,9 @@ OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();

## Positional Arguments

- `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective,
- `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective,
even if no such parameters are used in the objective it should be an argument in the function. For minibatching `p` can be used to pass in
a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it.
a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it.
This should return a scalar, the loss value, as the return output.
- `adtype`: see the Defining Optimization Functions via AD section below.

Expand Down Expand Up @@ -2649,7 +2650,7 @@ function NonlinearFunction{iip}(f::ODEFunction) where {iip}
end

@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac,
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)
f1 = ODEFunction(f1)
Expand All @@ -2663,13 +2664,13 @@ end
SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2),
typeof(mass_matrix),
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
typeof(initializeprobpmap)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
jac_prototype, W__prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
end
function SplitFunction{iip, specialize}(f1, f2;
Expand All @@ -2685,6 +2686,9 @@ function SplitFunction{iip, specialize}(f1, f2;
jac_prototype = __has_jac_prototype(f1) ?
f1.jac_prototype :
nothing,
W_prototype = __has_W_prototype(f1) ?
f1.W_prototype :
nothing,
sparsity = __has_sparsity(f1) ? f1.sparsity :
jac_prototype,
Wfact = __has_Wfact(f1) ? f1.Wfact : nothing,
Expand Down Expand Up @@ -2713,25 +2717,25 @@ function SplitFunction{iip, specialize}(f1, f2;

if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
initializeprobpmap, initializeprobpmap)
else
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
typeof(_func_cache), typeof(analytic),
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
typeof(jac_prototype), typeof(sparsity),
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(f1, f2,
mass_matrix, _func_cache, analytic, tgrad, jac,
jvp, vjp, jac_prototype,
jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
end
Expand Down
Loading