Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, D
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, HomotopyNonlinearFunction,
IntervalNonlinearFunction, BVPFunction,
DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction
DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ControlFunction

export OptimizationFunction, MultiObjectiveOptimizationFunction

Expand Down
2 changes: 1 addition & 1 deletion src/problems/implicit_discrete_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dt: the time step

### Constructors

- `ImplicitDiscreteProblem(f::ODEFunction,u0,tspan,p=NullParameters();kwargs...)` :
- `ImplicitDiscreteProblem(f::ImplicitDiscreteFunction,u0,tspan,p=NullParameters();kwargs...)` :
Defines the discrete problem with the specified functions.
- `ImplicitDiscreteProblem{isinplace,specialize}(f,u0,tspan,p=NullParameters();kwargs...)` :
Defines the discrete problem with the specified functions.
Expand Down
248 changes: 248 additions & 0 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{
initialization_data::ID
end

"""
$(TYPEDEF)
"""
abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end

@doc doc"""
$(TYPEDEF)

A representation of a ODE function `f` with inputs, defined by:

```math
\frac{dx}{dt} = f(x, u, p, t)
```
where `x` are the states of the system and `u` are the inputs (which may represent
different things in different contexts, such as control variables in optimal control).

Includes all of its related functions, such as the Jacobian of `f`, its gradient
with respect to time, and more. For all cases, `u0` is the initial condition,
`p` are the parameters, and `t` is the independent variable.

```julia
ODEInputFunction{iip, specialize}(f;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
control_jac = __has_controljac(f) ? f.controljac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = nothing,
indepsym = nothing,
paramsyms = nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing)
```

`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
See the section on `iip` for more details on in-place vs out-of-place handling.

- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
to determine that the equation is actually a BVP for differential algebraic equation (DAE)
if `M` is singular.
- `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\frac{df}{dx}``
- `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\frac{df}{du}``
- `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional
derivative ``\frac{df}{du} v``
- `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint
derivative ``\frac{df}{du}^\ast v``
- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
as the prototype and integrators will specialize on this structure where possible. Non-structured
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
- `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
as the prototype and integrators will specialize on this structure where possible. Non-structured
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
- `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``.
- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity
pattern of the `jac_prototype`. This specializes the Jacobian construction when using
finite differences and automatic differentiation to be computed in an accelerated manner
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.

## iip: In-Place vs Out-Of-Place
For more details on this argument, see the ODEFunction documentation.

## specialize: Controlling Compilation and Specialization
For more details on this argument, see the ODEFunction documentation.

## Fields
The fields of the ODEInputFunction type directly match the names of the inputs.
"""
struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV,
SYS, ID} <: AbstractODEInputFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
tgrad::Tt
jac::TJ
controljac::CTJ
jvp::JVP
vjp::VJP
jac_prototype::JP
controljac_prototype::CJP
sparsity::SP
Wfact::TW
Wfact_t::TWt
W_prototype::WP
paramjac::TPJ
observed::O
colorvec::TCV
sys::SYS
initialization_data::ID
end

"""
$(TYPEDEF)
"""
Expand Down Expand Up @@ -2493,6 +2596,7 @@ end
(f::ImplicitDiscreteFunction)(args...) = f.f(args...)
(f::DAEFunction)(args...) = f.f(args...)
(f::DDEFunction)(args...) = f.f(args...)
(f::ODEInputFunction)(args...) = f.f(args...)

function (f::DynamicalDDEFunction)(u, h, p, t)
ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t))
Expand Down Expand Up @@ -4595,6 +4699,149 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
end

function ODEInputFunction{iip, specialize}(f;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
controljac = __has_controljac(f) ? f.controljac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ?
f.jac_prototype :
nothing,
controljac_prototype = __has_controljac_prototype(f) ?
f.controljac_prototype :
nothing,
sparsity = __has_sparsity(f) ? f.sparsity :
jac_prototype,
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = nothing,
indepsym = nothing,
paramsyms = nothing,
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
) where {iip,
specialize
}
if mass_matrix === I && f isa Tuple
mass_matrix = ((I for i in 1:length(f))...,)
end

if (specialize === FunctionWrapperSpecialize) &&
!(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!")
end

if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
if iip
jac = (J, x, u, p, t) -> update_coefficients!(J, x, p, t) #(J,x,u,p,t)
else
jac = (x, u, p, t) -> update_coefficients(deepcopy(jac_prototype), x, p, t)
end
end

if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator)
if iip_bc
controljac = (J, x, u, p, t) -> update_coefficients!(J, u, p, t) #(J,x,u,p,t)
else
controljac = (x, u, p, t) -> update_coefficients(deepcopy(controljac_prototype), u, p, t)
end
end

if jac_prototype !== nothing && colorvec === nothing &&
ArrayInterface.fast_matrix_colors(jac_prototype)
_colorvec = ArrayInterface.matrix_colors(jac_prototype)
else
_colorvec = colorvec
end

jaciip = jac !== nothing ? isinplace(jac, 5, "jac", iip) : iip
controljaciip = controljac !== nothing ? isinplace(controljac, 5, "controljac", iip) : iip
tgradiip = tgrad !== nothing ? isinplace(tgrad, 5, "tgrad", iip) : iip
jvpiip = jvp !== nothing ? isinplace(jvp, 6, "jvp", iip) : iip
vjpiip = vjp !== nothing ? isinplace(vjp, 6, "vjp", iip) : iip
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 6, "Wfact", iip) : iip
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 6, "Wfact_t", iip) : iip
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 5, "paramjac", iip) : iip

nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
paramjaciip) .!= iip
if any(nonconforming)
nonconforming = findall(nonconforming)
functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
throw(NonconformingFunctionsError(functions))
end

_f = prepare_function(f)

sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
initdata = reconstruct_initialization_data(
initialization_data, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)

if specialize === NoSpecialize
ODEInputFunction{iip, specialize,
Any, Any, Any, Any,
Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype),
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Union{Nothing, OverrideInitData}}(
_f, mass_matrix, analytic, tgrad, jac, controljac,
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata)
elseif specialize === false
ODEInputFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata)}(_f, mass_matrix,
analytic, tgrad, jac, controljac,
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata)
else
ODEInputFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata)}(
_f, mass_matrix, analytic, tgrad,
jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata)
end
end

function ODEInputFunction{iip}(f; kwargs...) where {iip}
ODEInputFunction{iip, FullSpecialize}(f; kwargs...)
end
ODEInputFunction{iip}(f::ODEInputFunction; kwargs...) where {iip} = f
ODEInputFunction(f; kwargs...) = ODEInputFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...)
ODEInputFunction(f::ODEInputFunction; kwargs...) = f

########## Utility functions

function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
Expand Down Expand Up @@ -4628,6 +4875,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
__has_W_prototype(f) = isdefined(f, :W_prototype)
__has_paramjac(f) = isdefined(f, :paramjac)
__has_jac_prototype(f) = isdefined(f, :jac_prototype)
__has_controljac_prototype(f) = isdefined(f, :controljac_prototype)
__has_sparsity(f) = isdefined(f, :sparsity)
__has_mass_matrix(f) = isdefined(f, :mass_matrix)
__has_syms(f) = isdefined(f, :syms)
Expand Down
Empty file added src/solutions/solution_utils.jl
Empty file.
Loading