@@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{
20942094 initialization_data:: ID
20952095end
20962096
2097+ """
2098+ $(TYPEDEF)
2099+ """
2100+ abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end
2101+
2102+ @doc doc"""
2103+ $(TYPEDEF)
2104+
2105+ A representation of a ODE function `f` with inputs, defined by:
2106+
2107+ ```math
2108+ \f rac{dx}{dt} = f(x, u, p, t)
2109+ ```
2110+ where `x` are the states of the system and `u` are the inputs (which may represent
2111+ different things in different contexts, such as control variables in optimal control).
2112+
2113+ Includes all of its related functions, such as the Jacobian of `f`, its gradient
2114+ with respect to time, and more. For all cases, `u0` is the initial condition,
2115+ `p` are the parameters, and `t` is the independent variable.
2116+
2117+ ```julia
2118+ ODEInputFunction{iip, specialize}(f;
2119+ mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
2120+ analytic = __has_analytic(f) ? f.analytic : nothing,
2121+ tgrad= __has_tgrad(f) ? f.tgrad : nothing,
2122+ jac = __has_jac(f) ? f.jac : nothing,
2123+ control_jac = __has_controljac(f) ? f.controljac : nothing,
2124+ jvp = __has_jvp(f) ? f.jvp : nothing,
2125+ vjp = __has_vjp(f) ? f.vjp : nothing,
2126+ jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
2127+ controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing,
2128+ sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
2129+ paramjac = __has_paramjac(f) ? f.paramjac : nothing,
2130+ syms = nothing,
2131+ indepsym = nothing,
2132+ paramsyms = nothing,
2133+ colorvec = __has_colorvec(f) ? f.colorvec : nothing,
2134+ sys = __has_sys(f) ? f.sys : nothing)
2135+ ```
2136+
2137+ `f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
2138+ See the section on `iip` for more details on in-place vs out-of-place handling.
2139+
2140+ - `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
2141+ to determine that the equation is actually a BVP for differential algebraic equation (DAE)
2142+ if `M` is singular.
2143+ - `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\f rac{df}{dx}``
2144+ - `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\f rac{df}{du}``
2145+ - `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional
2146+ derivative ``\f rac{df}{du} v``
2147+ - `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint
2148+ derivative ``\f rac{df}{du}^\a st v``
2149+ - `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
2150+ if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
2151+ as the prototype and integrators will specialize on this structure where possible. Non-structured
2152+ sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
2153+ The default is `nothing`, which means a dense Jacobian.
2154+ - `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
2155+ if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
2156+ as the prototype and integrators will specialize on this structure where possible. Non-structured
2157+ sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
2158+ The default is `nothing`, which means a dense Jacobian.
2159+ - `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\f rac{df}{dp}``.
2160+ - `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity
2161+ pattern of the `jac_prototype`. This specializes the Jacobian construction when using
2162+ finite differences and automatic differentiation to be computed in an accelerated manner
2163+ based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
2164+ internally computed on demand when required. The cost of this operation is highly dependent
2165+ on the sparsity pattern.
2166+
2167+ ## iip: In-Place vs Out-Of-Place
2168+ For more details on this argument, see the ODEFunction documentation.
2169+
2170+ ## specialize: Controlling Compilation and Specialization
2171+ For more details on this argument, see the ODEFunction documentation.
2172+
2173+ ## Fields
2174+ The fields of the ODEInputFunction type directly match the names of the inputs.
2175+ """
2176+ struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2177+ JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV,
2178+ SYS, ID} <: AbstractODEInputFunction{iip}
2179+ f:: F
2180+ mass_matrix:: TMM
2181+ analytic:: Ta
2182+ tgrad:: Tt
2183+ jac:: TJ
2184+ controljac:: CTJ
2185+ jvp:: JVP
2186+ vjp:: VJP
2187+ jac_prototype:: JP
2188+ controljac_prototype:: CJP
2189+ sparsity:: SP
2190+ Wfact:: TW
2191+ Wfact_t:: TWt
2192+ W_prototype:: WP
2193+ paramjac:: TPJ
2194+ observed:: O
2195+ colorvec:: TCV
2196+ sys:: SYS
2197+ initialization_data:: ID
2198+ end
2199+
20972200"""
20982201$(TYPEDEF)
20992202"""
@@ -2493,6 +2596,7 @@ end
24932596(f:: ImplicitDiscreteFunction )(args... ) = f. f (args... )
24942597(f:: DAEFunction )(args... ) = f. f (args... )
24952598(f:: DDEFunction )(args... ) = f. f (args... )
2599+ (f:: ODEInputFunction )(args... ) = f. f (args... )
24962600
24972601function (f:: DynamicalDDEFunction )(u, h, p, t)
24982602 ArrayPartition (f. f1 (u. x[1 ], u. x[2 ], h, p, t), f. f2 (u. x[1 ], u. x[2 ], h, p, t))
@@ -4595,6 +4699,149 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
45954699 BatchIntegralFunction {calculated_iip} (f, integrand_prototype; kwargs... )
45964700end
45974701
4702+ function ODEInputFunction {iip, specialize} (f;
4703+ mass_matrix = __has_mass_matrix (f) ? f. mass_matrix :
4704+ I,
4705+ analytic = __has_analytic (f) ? f. analytic : nothing ,
4706+ tgrad = __has_tgrad (f) ? f. tgrad : nothing ,
4707+ jac = __has_jac (f) ? f. jac : nothing ,
4708+ controljac = __has_controljac (f) ? f. controljac : nothing ,
4709+ jvp = __has_jvp (f) ? f. jvp : nothing ,
4710+ vjp = __has_vjp (f) ? f. vjp : nothing ,
4711+ jac_prototype = __has_jac_prototype (f) ?
4712+ f. jac_prototype :
4713+ nothing ,
4714+ controljac_prototype = __has_controljac_prototype (f) ?
4715+ f. controljac_prototype :
4716+ nothing ,
4717+ sparsity = __has_sparsity (f) ? f. sparsity :
4718+ jac_prototype,
4719+ Wfact = __has_Wfact (f) ? f. Wfact : nothing ,
4720+ Wfact_t = __has_Wfact_t (f) ? f. Wfact_t : nothing ,
4721+ W_prototype = __has_W_prototype (f) ? f. W_prototype : nothing ,
4722+ paramjac = __has_paramjac (f) ? f. paramjac : nothing ,
4723+ syms = nothing ,
4724+ indepsym = nothing ,
4725+ paramsyms = nothing ,
4726+ observed = __has_observed (f) ? f. observed :
4727+ DEFAULT_OBSERVED,
4728+ colorvec = __has_colorvec (f) ? f. colorvec : nothing ,
4729+ sys = __has_sys (f) ? f. sys : nothing ,
4730+ initializeprob = __has_initializeprob (f) ? f. initializeprob : nothing ,
4731+ update_initializeprob! = __has_update_initializeprob! (f) ?
4732+ f. update_initializeprob! : nothing ,
4733+ initializeprobmap = __has_initializeprobmap (f) ? f. initializeprobmap : nothing ,
4734+ initializeprobpmap = __has_initializeprobpmap (f) ? f. initializeprobpmap : nothing ,
4735+ initialization_data = __has_initialization_data (f) ? f. initialization_data :
4736+ nothing ,
4737+ nlprob_data = __has_nlprob_data (f) ? f. nlprob_data : nothing
4738+ ) where {iip,
4739+ specialize
4740+ }
4741+ if mass_matrix === I && f isa Tuple
4742+ mass_matrix = ((I for i in 1 : length (f)). .. ,)
4743+ end
4744+
4745+ if (specialize === FunctionWrapperSpecialize) &&
4746+ ! (f isa FunctionWrappersWrappers. FunctionWrappersWrapper)
4747+ error (" FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!" )
4748+ end
4749+
4750+ if jac === nothing && isa (jac_prototype, AbstractSciMLOperator)
4751+ if iip
4752+ jac = (J, x, u, p, t) -> update_coefficients! (J, x, p, t) # (J,x,u,p,t)
4753+ else
4754+ jac = (x, u, p, t) -> update_coefficients (deepcopy (jac_prototype), x, p, t)
4755+ end
4756+ end
4757+
4758+ if controljac === nothing && isa (controljac_prototype, AbstractSciMLOperator)
4759+ if iip_bc
4760+ controljac = (J, x, u, p, t) -> update_coefficients! (J, u, p, t) # (J,x,u,p,t)
4761+ else
4762+ controljac = (x, u, p, t) -> update_coefficients (deepcopy (controljac_prototype), u, p, t)
4763+ end
4764+ end
4765+
4766+ if jac_prototype != = nothing && colorvec === nothing &&
4767+ ArrayInterface. fast_matrix_colors (jac_prototype)
4768+ _colorvec = ArrayInterface. matrix_colors (jac_prototype)
4769+ else
4770+ _colorvec = colorvec
4771+ end
4772+
4773+ jaciip = jac != = nothing ? isinplace (jac, 5 , " jac" , iip) : iip
4774+ controljaciip = controljac != = nothing ? isinplace (controljac, 5 , " controljac" , iip) : iip
4775+ tgradiip = tgrad != = nothing ? isinplace (tgrad, 5 , " tgrad" , iip) : iip
4776+ jvpiip = jvp != = nothing ? isinplace (jvp, 6 , " jvp" , iip) : iip
4777+ vjpiip = vjp != = nothing ? isinplace (vjp, 6 , " vjp" , iip) : iip
4778+ Wfactiip = Wfact != = nothing ? isinplace (Wfact, 6 , " Wfact" , iip) : iip
4779+ Wfact_tiip = Wfact_t != = nothing ? isinplace (Wfact_t, 6 , " Wfact_t" , iip) : iip
4780+ paramjaciip = paramjac != = nothing ? isinplace (paramjac, 5 , " paramjac" , iip) : iip
4781+
4782+ nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4783+ paramjaciip) .!= iip
4784+ if any (nonconforming)
4785+ nonconforming = findall (nonconforming)
4786+ functions = [" jac" , " tgrad" , " jvp" , " vjp" , " Wfact" , " Wfact_t" , " paramjac" ][nonconforming]
4787+ throw (NonconformingFunctionsError (functions))
4788+ end
4789+
4790+ _f = prepare_function (f)
4791+
4792+ sys = sys_or_symbolcache (sys, syms, paramsyms, indepsym)
4793+ initdata = reconstruct_initialization_data (
4794+ initialization_data, initializeprob, update_initializeprob!,
4795+ initializeprobmap, initializeprobpmap)
4796+
4797+ if specialize === NoSpecialize
4798+ ODEInputFunction{iip, specialize,
4799+ Any, Any, Any, Any,
4800+ Any, Any, Any, Any, typeof (jac_prototype), typeof (controljac_prototype),
4801+ typeof (sparsity), Any, Any, typeof (W_prototype), Any,
4802+ Any,
4803+ typeof (_colorvec),
4804+ typeof (sys), Union{Nothing, OverrideInitData}}(
4805+ _f, mass_matrix, analytic, tgrad, jac, controljac,
4806+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4807+ Wfact_t, W_prototype, paramjac,
4808+ observed, _colorvec, sys, initdata)
4809+ elseif specialize === false
4810+ ODEInputFunction{iip, FunctionWrapperSpecialize,
4811+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4812+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4813+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4814+ typeof (paramjac),
4815+ typeof (observed),
4816+ typeof (_colorvec),
4817+ typeof (sys), typeof (initdata)}(_f, mass_matrix,
4818+ analytic, tgrad, jac, controljac,
4819+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4820+ Wfact_t, W_prototype, paramjac,
4821+ observed, _colorvec, sys, initdata)
4822+ else
4823+ ODEInputFunction{iip, specialize,
4824+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4825+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4826+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4827+ typeof (paramjac),
4828+ typeof (observed),
4829+ typeof (_colorvec),
4830+ typeof (sys), typeof (initdata)}(
4831+ _f, mass_matrix, analytic, tgrad,
4832+ jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4833+ Wfact_t, W_prototype, paramjac,
4834+ observed, _colorvec, sys, initdata)
4835+ end
4836+ end
4837+
4838+ function ODEInputFunction {iip} (f; kwargs... ) where {iip}
4839+ ODEInputFunction {iip, FullSpecialize} (f; kwargs... )
4840+ end
4841+ ODEInputFunction {iip} (f:: ODEInputFunction ; kwargs... ) where {iip} = f
4842+ ODEInputFunction (f; kwargs... ) = ODEInputFunction {isinplace(f, 5), FullSpecialize} (f; kwargs... )
4843+ ODEInputFunction (f:: ODEInputFunction ; kwargs... ) = f
4844+
45984845# ######### Utility functions
45994846
46004847function sys_or_symbolcache (sys, syms, paramsyms, indepsym = nothing )
@@ -4628,6 +4875,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
46284875__has_W_prototype (f) = isdefined (f, :W_prototype )
46294876__has_paramjac (f) = isdefined (f, :paramjac )
46304877__has_jac_prototype (f) = isdefined (f, :jac_prototype )
4878+ __has_controljac_prototype (f) = isdefined (f, :controljac_prototype )
46314879__has_sparsity (f) = isdefined (f, :sparsity )
46324880__has_mass_matrix (f) = isdefined (f, :mass_matrix )
46334881__has_syms (f) = isdefined (f, :syms )
0 commit comments