diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 92bc2d8c3..f8e4c903d 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -822,7 +822,7 @@ export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, D DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, IncrementingODEFunction, NonlinearFunction, HomotopyNonlinearFunction, IntervalNonlinearFunction, BVPFunction, - DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction + DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ODEInputFunction export OptimizationFunction, MultiObjectiveOptimizationFunction diff --git a/src/problems/implicit_discrete_problems.jl b/src/problems/implicit_discrete_problems.jl index ae892e239..515bd48fb 100644 --- a/src/problems/implicit_discrete_problems.jl +++ b/src/problems/implicit_discrete_problems.jl @@ -27,7 +27,7 @@ dt: the time step ### Constructors -- `ImplicitDiscreteProblem(f::ODEFunction,u0,tspan,p=NullParameters();kwargs...)` : +- `ImplicitDiscreteProblem(f::ImplicitDiscreteFunction,u0,tspan,p=NullParameters();kwargs...)` : Defines the discrete problem with the specified functions. - `ImplicitDiscreteProblem{isinplace,specialize}(f,u0,tspan,p=NullParameters();kwargs...)` : Defines the discrete problem with the specified functions. diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 867efe858..1b61bd491 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{ initialization_data::ID end +""" +$(TYPEDEF) +""" +abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end + +@doc doc""" +$(TYPEDEF) + +A representation of a ODE function `f` with inputs, defined by: + +```math +\frac{dx}{dt} = f(x, u, p, t) +``` +where `x` are the states of the system and `u` are the inputs (which may represent +different things in different contexts, such as control variables in optimal control). + +Includes all of its related functions, such as the Jacobian of `f`, its gradient +with respect to time, and more. For all cases, `u0` is the initial condition, +`p` are the parameters, and `t` is the independent variable. + +```julia +ODEInputFunction{iip, specialize}(f; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad= __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + control_jac = __has_controljac(f) ? f.controljac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = nothing, + indepsym = nothing, + paramsyms = nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing) +``` + +`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`. +See the section on `iip` for more details on in-place vs out-of-place handling. + +- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used + to determine that the equation is actually a BVP for differential algebraic equation (DAE) + if `M` is singular. +- `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\frac{df}{dx}`` +- `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\frac{df}{du}`` +- `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional + derivative ``\frac{df}{du} v`` +- `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint + derivative ``\frac{df}{du}^\ast v`` +- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. + The default is `nothing`, which means a dense Jacobian. +- `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. + The default is `nothing`, which means a dense Jacobian. +- `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``. +- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity + pattern of the `jac_prototype`. This specializes the Jacobian construction when using + finite differences and automatic differentiation to be computed in an accelerated manner + based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be + internally computed on demand when required. The cost of this operation is highly dependent + on the sparsity pattern. + +## iip: In-Place vs Out-Of-Place +For more details on this argument, see the ODEFunction documentation. + +## specialize: Controlling Compilation and Specialization +For more details on this argument, see the ODEFunction documentation. + +## Fields +The fields of the ODEInputFunction type directly match the names of the inputs. +""" +struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, + JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, + SYS, ID} <: AbstractODEInputFunction{iip} + f::F + mass_matrix::TMM + analytic::Ta + tgrad::Tt + jac::TJ + controljac::CTJ + jvp::JVP + vjp::VJP + jac_prototype::JP + controljac_prototype::CJP + sparsity::SP + Wfact::TW + Wfact_t::TWt + W_prototype::WP + paramjac::TPJ + observed::O + colorvec::TCV + sys::SYS + initialization_data::ID +end + """ $(TYPEDEF) """ @@ -2493,6 +2596,7 @@ end (f::ImplicitDiscreteFunction)(args...) = f.f(args...) (f::DAEFunction)(args...) = f.f(args...) (f::DDEFunction)(args...) = f.f(args...) +(f::ODEInputFunction)(args...) = f.f(args...) function (f::DynamicalDDEFunction)(u, h, p, t) ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t)) @@ -4595,6 +4699,149 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...) BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...) end +function ODEInputFunction{iip, specialize}(f; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : + I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad = __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + controljac = __has_controljac(f) ? f.controljac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? + f.jac_prototype : + nothing, + controljac_prototype = __has_controljac_prototype(f) ? + f.controljac_prototype : + nothing, + sparsity = __has_sparsity(f) ? f.sparsity : + jac_prototype, + Wfact = __has_Wfact(f) ? f.Wfact : nothing, + Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, + W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = nothing, + indepsym = nothing, + paramsyms = nothing, + observed = __has_observed(f) ? f.observed : + DEFAULT_OBSERVED, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing, + initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, + update_initializeprob! = __has_update_initializeprob!(f) ? + f.update_initializeprob! : nothing, + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, + initialization_data = __has_initialization_data(f) ? f.initialization_data : + nothing, + nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing +) where {iip, + specialize +} + if mass_matrix === I && f isa Tuple + mass_matrix = ((I for i in 1:length(f))...,) + end + + if (specialize === FunctionWrapperSpecialize) && + !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper) + error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!") + end + + if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) + if iip + jac = (J, x, u, p, t) -> update_coefficients!(J, x, p, t) #(J,x,u,p,t) + else + jac = (x, u, p, t) -> update_coefficients(deepcopy(jac_prototype), x, p, t) + end + end + + if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator) + if iip_bc + controljac = (J, x, u, p, t) -> update_coefficients!(J, u, p, t) #(J,x,u,p,t) + else + controljac = (x, u, p, t) -> update_coefficients(deepcopy(controljac_prototype), u, p, t) + end + end + + if jac_prototype !== nothing && colorvec === nothing && + ArrayInterface.fast_matrix_colors(jac_prototype) + _colorvec = ArrayInterface.matrix_colors(jac_prototype) + else + _colorvec = colorvec + end + + jaciip = jac !== nothing ? isinplace(jac, 5, "jac", iip) : iip + controljaciip = controljac !== nothing ? isinplace(controljac, 5, "controljac", iip) : iip + tgradiip = tgrad !== nothing ? isinplace(tgrad, 5, "tgrad", iip) : iip + jvpiip = jvp !== nothing ? isinplace(jvp, 6, "jvp", iip) : iip + vjpiip = vjp !== nothing ? isinplace(vjp, 6, "vjp", iip) : iip + Wfactiip = Wfact !== nothing ? isinplace(Wfact, 6, "Wfact", iip) : iip + Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 6, "Wfact_t", iip) : iip + paramjaciip = paramjac !== nothing ? isinplace(paramjac, 5, "paramjac", iip) : iip + + nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip, + paramjaciip) .!= iip + if any(nonconforming) + nonconforming = findall(nonconforming) + functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] + throw(NonconformingFunctionsError(functions)) + end + + _f = prepare_function(f) + + sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) + + if specialize === NoSpecialize + ODEInputFunction{iip, specialize, + Any, Any, Any, Any, + Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype), + typeof(sparsity), Any, Any, typeof(W_prototype), Any, + Any, + typeof(_colorvec), + typeof(sys), Union{Nothing, OverrideInitData}}( + _f, mass_matrix, analytic, tgrad, jac, controljac, + jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact, + Wfact_t, W_prototype, paramjac, + observed, _colorvec, sys, initdata) + elseif specialize === false + ODEInputFunction{iip, FunctionWrapperSpecialize, + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), + typeof(paramjac), + typeof(observed), + typeof(_colorvec), + typeof(sys), typeof(initdata)}(_f, mass_matrix, + analytic, tgrad, jac, controljac, + jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact, + Wfact_t, W_prototype, paramjac, + observed, _colorvec, sys, initdata) + else + ODEInputFunction{iip, specialize, + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), + typeof(paramjac), + typeof(observed), + typeof(_colorvec), + typeof(sys), typeof(initdata)}( + _f, mass_matrix, analytic, tgrad, + jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact, + Wfact_t, W_prototype, paramjac, + observed, _colorvec, sys, initdata) + end +end + +function ODEInputFunction{iip}(f; kwargs...) where {iip} + ODEInputFunction{iip, FullSpecialize}(f; kwargs...) +end +ODEInputFunction{iip}(f::ODEInputFunction; kwargs...) where {iip} = f +ODEInputFunction(f; kwargs...) = ODEInputFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...) +ODEInputFunction(f::ODEInputFunction; kwargs...) = f + ########## Utility functions function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing) @@ -4628,6 +4875,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t) __has_W_prototype(f) = isdefined(f, :W_prototype) __has_paramjac(f) = isdefined(f, :paramjac) __has_jac_prototype(f) = isdefined(f, :jac_prototype) +__has_controljac_prototype(f) = isdefined(f, :controljac_prototype) __has_sparsity(f) = isdefined(f, :sparsity) __has_mass_matrix(f) = isdefined(f, :mass_matrix) __has_syms(f) = isdefined(f, :syms) diff --git a/src/solutions/solution_utils.jl b/src/solutions/solution_utils.jl new file mode 100644 index 000000000..e69de29bb