diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index e783c5277..c62e8d838 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -455,6 +455,7 @@ SplitFunction{iip,specialize}(f1,f2; jvp = __has_jvp(f1) ? f1.jvp : nothing, vjp = __has_vjp(f1) ? f1.vjp : nothing, jac_prototype = __has_jac_prototype(f1) ? f1.jac_prototype : nothing, + W_prototype = __has_W_prototype(f1) ? f1.W_prototype : nothing, sparsity = __has_sparsity(f1) ? f1.sparsity : jac_prototype, paramjac = __has_paramjac(f1) ? f1.paramjac : nothing, colorvec = __has_colorvec(f1) ? f1.colorvec : nothing, @@ -484,6 +485,11 @@ the usage of the `SplitFunction`. These include: as the prototype and integrators will specialize on this structure where possible. Non-structured sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. The default is `nothing`, which means a dense Jacobian. +- `W_prototype`: a prototype matrix matching the type that matches the W matrix. For example, + if the Jacobian is tridiagonal, and the mass_matrix is diagonal, then an appropriately sized `Tridiagonal` + matrix can be used as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the W matrix. + The default is `nothing`, which means a W of appropriate type for the jacobian and linear solver - `paramjac(pJ,u,p,t)`: returns the parameter Jacobian ``\frac{df_1}{dp}``. - `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity pattern of the `jac_prototype`. This specializes the Jacobian construction when using @@ -3080,7 +3086,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f @add_kwonly function SplitSDEFunction(f1, f2, g, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, - jac_prototype, Wfact, Wfact_t, paramjac, observed, + jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys) f1 = f1 isa AbstractSciMLOperator ? f1 : SDEFunction(f1) f2 = SDEFunction(f2) @@ -3091,7 +3097,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), typeof(sys)}(f1, f2, mass_matrix, cache, analytic, tgrad, jac, - jac_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys) + jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys) end function SplitSDEFunction{iip, specialize}(f1, f2, g;