From 24d40964817fe3b2d2c12c8fc8a45e5fd61a134e Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 16:02:41 -0400 Subject: [PATCH 01/10] feat: add ControlFunction --- src/interpolation.jl | 24 ++++++++ src/scimlfunctions.jl | 103 ++++++++++++++++++++++++++++++++ src/solutions/solution_utils.jl | 0 3 files changed, 127 insertions(+) create mode 100644 src/solutions/solution_utils.jl diff --git a/src/interpolation.jl b/src/interpolation.jl index 4b08e4eb5f..b022b78a1b 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -619,3 +619,27 @@ strip_interpolation(id::AbstractDiffEqInterpolation) = id strip_interpolation(id::HermiteInterpolation) = id strip_interpolation(id::LinearInterpolation) = id strip_interpolation(id::ConstantInterpolation) = id + +""" +Return the maximum value of a solution trajectory. Uses the interpolating polynomial +to compute the maximum (i.e. is not simply the largest value in the sol.u array.) +""" +function maxsol(sol::AbstractODESolution) + +end + +""" +Return the minimum value of a solution trajectory. Uses the interpolating polynomial +to compute the minimum (i.e. is not simply the smallest value in the sol.u array.) +""" +function minsol(sol::AbstractODESolution) + +end + +""" + +""" +function integralnorm(sol::AbstractODESolution) + +end + diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 867efe858d..7484290dbd 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{ initialization_data::ID end +""" +$(TYPEDEF) +""" +abstract type AbstractControlFunction{iip} <: AbstractDiffEqFunction{iip} end + +@doc doc""" +$(TYPEDEF) + +A representation of a optimal control function `f`, defined by: + +```math +\frac{dx}{dt} = f(x, u, p, t) +``` +where `x` are the states of the system and `u` are the inputs (or control variables). + +and all of its related functions, such as the Jacobian of `f`, its gradient +with respect to time, and more. For all cases, `u0` is the initial condition, +`p` are the parameters, and `t` is the independent variable. + +```julia +ControlFunction{iip, specialize}(f; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad= __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + control_jac = __has_controljac(f) ? f.controljac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = nothing, + indepsym = nothing, + paramsyms = nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing) +``` + +`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`. +See the section on `iip` for more details on in-place vs out-of-place handling. + +- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used + to determine that the equation is actually a BVP for differential algebraic equation (DAE) + if `M` is singular. +- `jac(J,dx,x,p,gamma,t)` or `J=jac(dx,x,p,gamma,t)`: returns ``\frac{df}{dx}`` +- `control_jac(J,du,u,p,gamma,t)` or `J=control_jac(du,u,p,gamma,t)`: returns ``\frac{df}{du}`` +- `jvp(Jv,v,du,u,p,gamma,t)` or `Jv=jvp(v,du,u,p,gamma,t)`: returns the directional + derivative ``\frac{df}{du} v`` +- `vjp(Jv,v,du,u,p,gamma,t)` or `Jv=vjp(v,du,u,p,gamma,t)`: returns the adjoint + derivative ``\frac{df}{du}^\ast v`` +- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. + The default is `nothing`, which means a dense Jacobian. +- `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. + The default is `nothing`, which means a dense Jacobian. +- `paramjac(pJ,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``. +- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity + pattern of the `jac_prototype`. This specializes the Jacobian construction when using + finite differences and automatic differentiation to be computed in an accelerated manner + based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be + internally computed on demand when required. The cost of this operation is highly dependent + on the sparsity pattern. + +## iip: In-Place vs Out-Of-Place + +For more details on this argument, see the ODEFunction documentation. + +## specialize: Controlling Compilation and Specialization + +For more details on this argument, see the ODEFunction documentation. + +## Fields +# +The fields of the ControlFunction type directly match the names of the inputs. +""" +struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, + JP, CJP, SP, TPJ, O, TCV, CTCV, + SYS, ID} <: AbstractControlFunction{iip} + f::F + mass_matrix::TMM + analytic::Ta + tgrad::Tt + jac::TJ + controljac::CTJ + jvp::JVP + vjp::VJP + jac_prototype::JP + controljac_prototype::CJP + sparsity::SP + paramjac::TPJ + observed::O + colorvec::TCV + controlcolorvec::CTCV + sys::SYS + initialization_data::ID +end + """ $(TYPEDEF) """ diff --git a/src/solutions/solution_utils.jl b/src/solutions/solution_utils.jl new file mode 100644 index 0000000000..e69de29bb2 From 9b63c96cbcab5801861beeb5393ec28fdd7a8869 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 17:08:06 -0400 Subject: [PATCH 02/10] add ControlFunction constructor --- src/problems/implicit_discrete_problems.jl | 2 +- src/scimlfunctions.jl | 147 ++++++++++++++++++++- 2 files changed, 146 insertions(+), 3 deletions(-) diff --git a/src/problems/implicit_discrete_problems.jl b/src/problems/implicit_discrete_problems.jl index ae892e239d..515bd48fb7 100644 --- a/src/problems/implicit_discrete_problems.jl +++ b/src/problems/implicit_discrete_problems.jl @@ -27,7 +27,7 @@ dt: the time step ### Constructors -- `ImplicitDiscreteProblem(f::ODEFunction,u0,tspan,p=NullParameters();kwargs...)` : +- `ImplicitDiscreteProblem(f::ImplicitDiscreteFunction,u0,tspan,p=NullParameters();kwargs...)` : Defines the discrete problem with the specified functions. - `ImplicitDiscreteProblem{isinplace,specialize}(f,u0,tspan,p=NullParameters();kwargs...)` : Defines the discrete problem with the specified functions. diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 7484290dbd..60ca15ae43 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2176,7 +2176,7 @@ For more details on this argument, see the ODEFunction documentation. The fields of the ControlFunction type directly match the names of the inputs. """ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, - JP, CJP, SP, TPJ, O, TCV, CTCV, + JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, CTCV, SYS, ID} <: AbstractControlFunction{iip} f::F mass_matrix::TMM @@ -2189,10 +2189,12 @@ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, jac_prototype::JP controljac_prototype::CJP sparsity::SP + Wfact::TW + Wfact_t::TWt + W_prototype::WP paramjac::TPJ observed::O colorvec::TCV - controlcolorvec::CTCV sys::SYS initialization_data::ID end @@ -4698,6 +4700,146 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...) BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...) end +function ControlFunction{iip, specialize}(f; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : + I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad = __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + controljac = __has_controljac(f) ? f.controljac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? + f.jac_prototype : + nothing, + controljac_prototype = __has_controljac_prototype(f) ? + f.controljac_prototype : + nothing, + sparsity = __has_sparsity(f) ? f.sparsity : + jac_prototype, + Wfact = __has_Wfact(f) ? f.Wfact : nothing, + Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, + W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + observed = __has_observed(f) ? f.observed : + DEFAULT_OBSERVED, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing, + initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, + update_initializeprob! = __has_update_initializeprob!(f) ? + f.update_initializeprob! : nothing, + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, + initialization_data = __has_initialization_data(f) ? f.initialization_data : + nothing, + nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing +) where {iip, + specialize +} + if mass_matrix === I && f isa Tuple + mass_matrix = ((I for i in 1:length(f))...,) + end + + if (specialize === FunctionWrapperSpecialize) && + !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper) + error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!") + end + + if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) + if iip + jac = update_coefficients! #(J,u,p,t) + else + jac = (u, p, t) -> update_coefficients(deepcopy(jac_prototype), u, p, t) + end + end + + if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator) + if iip_bc + controljac = update_coefficients! #(J,u,p,t) + else + controljac = (u, p, t) -> update_coefficients!(deepcopy(controljac_prototype), u, p, t) + end + end + + if jac_prototype !== nothing && colorvec === nothing && + ArrayInterface.fast_matrix_colors(jac_prototype) + _colorvec = ArrayInterface.matrix_colors(jac_prototype) + else + _colorvec = colorvec + end + + jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip + controljaciip = controljac !== nothing ? isinplace(controljac, 4, "controljac", iip) : iip + tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip + jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip + vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip + Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip + Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip + paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip + + nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip, + paramjaciip) .!= iip + if any(nonconforming) + nonconforming = findall(nonconforming) + functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] + throw(NonconformingFunctionsError(functions)) + end + + _f = prepare_function(f) + + sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + initdata = reconstruct_initialization_data( + initialization_data, initializeprob, update_initializeprob!, + initializeprobmap, initializeprobpmap) + + if specialize === NoSpecialize + ControlFunction{iip, specialize, + Any, Any, Any, Any, + Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype), + typeof(sparsity), Any, Any, typeof(W_prototype), Any, + Any, + typeof(_colorvec), + typeof(sys), Union{Nothing, OverrideInitData}}( + _f, mass_matrix, analytic, tgrad, jac, controljac, + jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact, + Wfact_t, W_prototype, paramjac, + observed, _colorvec, sys, initdata) + elseif specialize === false + ControlFunction{iip, FunctionWrapperSpecialize, + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), + typeof(paramjac), + typeof(observed), + typeof(_colorvec), + typeof(sys), typeof(initdata)}(_f, mass_matrix, + analytic, tgrad, jac, controljac, + jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact, + Wfact_t, W_prototype, paramjac, + observed, _colorvec, sys, initdata) + else + ControlFunction{iip, specialize, + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), + typeof(paramjac), + typeof(observed), + typeof(_colorvec), + typeof(sys), typeof(initdata)}( + _f, mass_matrix, analytic, tgrad, + jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact, + Wfact_t, W_prototype, paramjac, + observed, _colorvec, sys, initdata) + end +end + +function ODEFunction{iip}(f; kwargs...) where {iip} + ODEFunction{iip, FullSpecialize}(f; kwargs...) +end +ODEFunction{iip}(f::ODEFunction; kwargs...) where {iip} = f +ODEFunction(f; kwargs...) = ODEFunction{isinplace(f, 4), FullSpecialize}(f; kwargs...) +ODEFunction(f::ODEFunction; kwargs...) = f + ########## Utility functions function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing) @@ -4731,6 +4873,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t) __has_W_prototype(f) = isdefined(f, :W_prototype) __has_paramjac(f) = isdefined(f, :paramjac) __has_jac_prototype(f) = isdefined(f, :jac_prototype) +__has_controljac_prototype(f) = isdefined(f, :controljac_prototype) __has_sparsity(f) = isdefined(f, :sparsity) __has_mass_matrix(f) = isdefined(f, :mass_matrix) __has_syms(f) = isdefined(f, :syms) From 446dc07472e5dbcf69497348fff25a1880b3ec06 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 17:27:18 -0400 Subject: [PATCH 03/10] add export --- src/SciMLBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 92bc2d8c3a..6b0537396b 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -822,7 +822,7 @@ export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, D DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, IncrementingODEFunction, NonlinearFunction, HomotopyNonlinearFunction, IntervalNonlinearFunction, BVPFunction, - DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction + DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ControlFunction export OptimizationFunction, MultiObjectiveOptimizationFunction From 8119fa83c2d01c663fd93af02a0a87b5120b987a Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 17:36:10 -0400 Subject: [PATCH 04/10] fix: fix ControlFunction constructors --- src/scimlfunctions.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 60ca15ae43..c59277a93b 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4833,12 +4833,12 @@ function ControlFunction{iip, specialize}(f; end end -function ODEFunction{iip}(f; kwargs...) where {iip} - ODEFunction{iip, FullSpecialize}(f; kwargs...) +function ControlFunction{iip}(f; kwargs...) where {iip} + ControlFunction{iip, FullSpecialize}(f; kwargs...) end -ODEFunction{iip}(f::ODEFunction; kwargs...) where {iip} = f -ODEFunction(f; kwargs...) = ODEFunction{isinplace(f, 4), FullSpecialize}(f; kwargs...) -ODEFunction(f::ODEFunction; kwargs...) = f +ControlFunction{iip}(f::ControlFunction; kwargs...) where {iip} = f +ControlFunction(f; kwargs...) = ControlFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...) +ControlFunction(f::ControlFunction; kwargs...) = f ########## Utility functions From a12118a1681402307f85d695c5f3df99c8964915 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 18:07:22 -0400 Subject: [PATCH 05/10] doc: fix docstring --- src/scimlfunctions.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c59277a93b..9e8da28d1e 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2109,7 +2109,7 @@ A representation of a optimal control function `f`, defined by: ``` where `x` are the states of the system and `u` are the inputs (or control variables). -and all of its related functions, such as the Jacobian of `f`, its gradient +Includes all of its related functions, such as the Jacobian of `f`, its gradient with respect to time, and more. For all cases, `u0` is the initial condition, `p` are the parameters, and `t` is the independent variable. @@ -2164,15 +2164,12 @@ See the section on `iip` for more details on in-place vs out-of-place handling. on the sparsity pattern. ## iip: In-Place vs Out-Of-Place - For more details on this argument, see the ODEFunction documentation. ## specialize: Controlling Compilation and Specialization - For more details on this argument, see the ODEFunction documentation. ## Fields -# The fields of the ControlFunction type directly match the names of the inputs. """ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, From d6fdf9d226dc2c80ec0788862c73c6561c9608b7 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 18:10:20 -0400 Subject: [PATCH 06/10] fix function signature --- src/scimlfunctions.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 9e8da28d1e..28947b52e1 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4718,6 +4718,9 @@ function ControlFunction{iip, specialize}(f; Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing, paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = nothing, + indepsym = nothing, + paramsyms = nothing, observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, From f4add67b766450fb402b1f6626e01af23874b4cc Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 18:29:44 -0400 Subject: [PATCH 07/10] add dispatch --- src/scimlfunctions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 28947b52e1..cf010751b7 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2173,7 +2173,7 @@ For more details on this argument, see the ODEFunction documentation. The fields of the ControlFunction type directly match the names of the inputs. """ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, - JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, CTCV, + JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, SYS, ID} <: AbstractControlFunction{iip} f::F mass_matrix::TMM @@ -2595,6 +2595,7 @@ end (f::ImplicitDiscreteFunction)(args...) = f.f(args...) (f::DAEFunction)(args...) = f.f(args...) (f::DDEFunction)(args...) = f.f(args...) +(f::ControlFunction)(args...) = f.f(args...) function (f::DynamicalDDEFunction)(u, h, p, t) ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t)) From 1f14548a53430d3c61b89a4bb6215ef32a768e13 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 16 Apr 2025 18:31:27 -0400 Subject: [PATCH 08/10] remove interpolation stuff --- src/interpolation.jl | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/interpolation.jl b/src/interpolation.jl index b022b78a1b..4b08e4eb5f 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -619,27 +619,3 @@ strip_interpolation(id::AbstractDiffEqInterpolation) = id strip_interpolation(id::HermiteInterpolation) = id strip_interpolation(id::LinearInterpolation) = id strip_interpolation(id::ConstantInterpolation) = id - -""" -Return the maximum value of a solution trajectory. Uses the interpolating polynomial -to compute the maximum (i.e. is not simply the largest value in the sol.u array.) -""" -function maxsol(sol::AbstractODESolution) - -end - -""" -Return the minimum value of a solution trajectory. Uses the interpolating polynomial -to compute the minimum (i.e. is not simply the smallest value in the sol.u array.) -""" -function minsol(sol::AbstractODESolution) - -end - -""" - -""" -function integralnorm(sol::AbstractODESolution) - -end - From df3fb57f29a7eb4fe1b7f55be91171bcd23933a2 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 21 Apr 2025 10:48:53 -0400 Subject: [PATCH 09/10] fix: correct argument number and signatures --- src/scimlfunctions.jl | 69 ++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index cf010751b7..1b61bd491d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2097,24 +2097,25 @@ end """ $(TYPEDEF) """ -abstract type AbstractControlFunction{iip} <: AbstractDiffEqFunction{iip} end +abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end @doc doc""" $(TYPEDEF) -A representation of a optimal control function `f`, defined by: +A representation of a ODE function `f` with inputs, defined by: ```math \frac{dx}{dt} = f(x, u, p, t) ``` -where `x` are the states of the system and `u` are the inputs (or control variables). +where `x` are the states of the system and `u` are the inputs (which may represent +different things in different contexts, such as control variables in optimal control). Includes all of its related functions, such as the Jacobian of `f`, its gradient with respect to time, and more. For all cases, `u0` is the initial condition, `p` are the parameters, and `t` is the independent variable. ```julia -ControlFunction{iip, specialize}(f; +ODEInputFunction{iip, specialize}(f; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, analytic = __has_analytic(f) ? f.analytic : nothing, tgrad= __has_tgrad(f) ? f.tgrad : nothing, @@ -2139,11 +2140,11 @@ See the section on `iip` for more details on in-place vs out-of-place handling. - `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used to determine that the equation is actually a BVP for differential algebraic equation (DAE) if `M` is singular. -- `jac(J,dx,x,p,gamma,t)` or `J=jac(dx,x,p,gamma,t)`: returns ``\frac{df}{dx}`` -- `control_jac(J,du,u,p,gamma,t)` or `J=control_jac(du,u,p,gamma,t)`: returns ``\frac{df}{du}`` -- `jvp(Jv,v,du,u,p,gamma,t)` or `Jv=jvp(v,du,u,p,gamma,t)`: returns the directional +- `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\frac{df}{dx}`` +- `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\frac{df}{du}`` +- `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional derivative ``\frac{df}{du} v`` -- `vjp(Jv,v,du,u,p,gamma,t)` or `Jv=vjp(v,du,u,p,gamma,t)`: returns the adjoint +- `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint derivative ``\frac{df}{du}^\ast v`` - `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used @@ -2155,7 +2156,7 @@ See the section on `iip` for more details on in-place vs out-of-place handling. as the prototype and integrators will specialize on this structure where possible. Non-structured sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. The default is `nothing`, which means a dense Jacobian. -- `paramjac(pJ,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``. +- `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``. - `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity pattern of the `jac_prototype`. This specializes the Jacobian construction when using finite differences and automatic differentiation to be computed in an accelerated manner @@ -2170,11 +2171,11 @@ For more details on this argument, see the ODEFunction documentation. For more details on this argument, see the ODEFunction documentation. ## Fields -The fields of the ControlFunction type directly match the names of the inputs. +The fields of the ODEInputFunction type directly match the names of the inputs. """ -struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, +struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP, JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, ID} <: AbstractControlFunction{iip} + SYS, ID} <: AbstractODEInputFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -2595,7 +2596,7 @@ end (f::ImplicitDiscreteFunction)(args...) = f.f(args...) (f::DAEFunction)(args...) = f.f(args...) (f::DDEFunction)(args...) = f.f(args...) -(f::ControlFunction)(args...) = f.f(args...) +(f::ODEInputFunction)(args...) = f.f(args...) function (f::DynamicalDDEFunction)(u, h, p, t) ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t)) @@ -4698,7 +4699,7 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...) BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...) end -function ControlFunction{iip, specialize}(f; +function ODEInputFunction{iip, specialize}(f; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, analytic = __has_analytic(f) ? f.analytic : nothing, @@ -4748,17 +4749,17 @@ function ControlFunction{iip, specialize}(f; if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) if iip - jac = update_coefficients! #(J,u,p,t) + jac = (J, x, u, p, t) -> update_coefficients!(J, x, p, t) #(J,x,u,p,t) else - jac = (u, p, t) -> update_coefficients(deepcopy(jac_prototype), u, p, t) + jac = (x, u, p, t) -> update_coefficients(deepcopy(jac_prototype), x, p, t) end end if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator) if iip_bc - controljac = update_coefficients! #(J,u,p,t) + controljac = (J, x, u, p, t) -> update_coefficients!(J, u, p, t) #(J,x,u,p,t) else - controljac = (u, p, t) -> update_coefficients!(deepcopy(controljac_prototype), u, p, t) + controljac = (x, u, p, t) -> update_coefficients(deepcopy(controljac_prototype), u, p, t) end end @@ -4769,14 +4770,14 @@ function ControlFunction{iip, specialize}(f; _colorvec = colorvec end - jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip - controljaciip = controljac !== nothing ? isinplace(controljac, 4, "controljac", iip) : iip - tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip - jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip - vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip - Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip - Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip - paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip + jaciip = jac !== nothing ? isinplace(jac, 5, "jac", iip) : iip + controljaciip = controljac !== nothing ? isinplace(controljac, 5, "controljac", iip) : iip + tgradiip = tgrad !== nothing ? isinplace(tgrad, 5, "tgrad", iip) : iip + jvpiip = jvp !== nothing ? isinplace(jvp, 6, "jvp", iip) : iip + vjpiip = vjp !== nothing ? isinplace(vjp, 6, "vjp", iip) : iip + Wfactiip = Wfact !== nothing ? isinplace(Wfact, 6, "Wfact", iip) : iip + Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 6, "Wfact_t", iip) : iip + paramjaciip = paramjac !== nothing ? isinplace(paramjac, 5, "paramjac", iip) : iip nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip, paramjaciip) .!= iip @@ -4794,7 +4795,7 @@ function ControlFunction{iip, specialize}(f; initializeprobmap, initializeprobpmap) if specialize === NoSpecialize - ControlFunction{iip, specialize, + ODEInputFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype), typeof(sparsity), Any, Any, typeof(W_prototype), Any, @@ -4806,7 +4807,7 @@ function ControlFunction{iip, specialize}(f; Wfact_t, W_prototype, paramjac, observed, _colorvec, sys, initdata) elseif specialize === false - ControlFunction{iip, FunctionWrapperSpecialize, + ODEInputFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), @@ -4819,7 +4820,7 @@ function ControlFunction{iip, specialize}(f; Wfact_t, W_prototype, paramjac, observed, _colorvec, sys, initdata) else - ControlFunction{iip, specialize, + ODEInputFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), @@ -4834,12 +4835,12 @@ function ControlFunction{iip, specialize}(f; end end -function ControlFunction{iip}(f; kwargs...) where {iip} - ControlFunction{iip, FullSpecialize}(f; kwargs...) +function ODEInputFunction{iip}(f; kwargs...) where {iip} + ODEInputFunction{iip, FullSpecialize}(f; kwargs...) end -ControlFunction{iip}(f::ControlFunction; kwargs...) where {iip} = f -ControlFunction(f; kwargs...) = ControlFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...) -ControlFunction(f::ControlFunction; kwargs...) = f +ODEInputFunction{iip}(f::ODEInputFunction; kwargs...) where {iip} = f +ODEInputFunction(f; kwargs...) = ODEInputFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...) +ODEInputFunction(f::ODEInputFunction; kwargs...) = f ########## Utility functions From 97115af910c70106079fcf49a12b22bd47dbf17d Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 21 Apr 2025 11:23:53 -0400 Subject: [PATCH 10/10] change export --- src/SciMLBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 6b0537396b..f8e4c903d8 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -822,7 +822,7 @@ export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, D DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, IncrementingODEFunction, NonlinearFunction, HomotopyNonlinearFunction, IntervalNonlinearFunction, BVPFunction, - DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ControlFunction + DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ODEInputFunction export OptimizationFunction, MultiObjectiveOptimizationFunction