@@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
5151 observed:: Vector{Equation}
5252 """ System of constraints that must be satisfied by the solution to the system."""
5353 constraintsystem:: Union{Nothing, ConstraintsSystem}
54+ """ A set of expressions defining the costs of the system for optimal control."""
55+ costs:: Vector
56+ """ Takes the cost vector and returns a scalar for optimization."""
57+ coalesce:: Function
5458 """
5559 Time-derivative matrix. Note: this field will not be defined until
5660 [`calculate_tgrad`](@ref) is called on the system.
@@ -338,7 +342,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
338342 metadata, gui_metadata, is_dde, tstops, checks = checks)
339343end
340344
341- function ODESystem (eqs, iv; constraints = Equation[], kwargs... )
345+ function ODESystem (eqs, iv; constraints = Equation[], costs = Equation[], kwargs... )
342346 diffvars, allunknowns, ps, eqs = process_equations (eqs, iv)
343347
344348 for eq in get (kwargs, :parameter_dependencies , Equation[])
@@ -384,6 +388,13 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
384388 end
385389 end
386390
391+ if ! isempty (costs)
392+ coststs, costps = process_costs (costs, allunknowns, new_ps, iv)
393+ for p in costps
394+ ! in (p, new_ps) && push! (new_ps, p)
395+ end
396+ end
397+
387398 return ODESystem (eqs, iv, collect (Iterators. flatten ((diffvars, algevars, consvars))),
388399 collect (new_ps); constraintsystem, kwargs... )
389400end
@@ -733,22 +744,52 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
733744 return nothing
734745end
735746
736- # Validate that all the variables in the BVP constraints are well-formed states or parameters.
737- # - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
738- # - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
747+ """
748+ Build the constraint system for the ODESystem.
749+ """
739750function process_constraint_system (
740751 constraints:: Vector{Equation} , sts, ps, iv; consname = :cons )
741752 isempty (constraints) && return nothing
742753
743754 constraintsts = OrderedSet ()
744755 constraintps = OrderedSet ()
745-
746756 for cons in constraints
747757 collect_vars! (constraintsts, constraintps, cons, iv)
748758 end
749759
750760 # Validate the states.
751- for var in constraintsts
761+ validate_vars_and_find_ps! (coststs, costps, sts, iv)
762+
763+ ConstraintsSystem (
764+ constraints, collect (constraintsts), collect (constraintps); name = consname)
765+ end
766+
767+ """
768+ Process the costs for the constraint system.
769+ """
770+ function process_costs (costs:: Vector{Equation} , sts, ps, iv)
771+ coststs = OrderedSet ()
772+ costps = OrderedSet ()
773+ for cost in costs
774+ collect_vars! (coststs, costps, cost, iv)
775+ end
776+
777+ validate_vars_and_find_ps! (coststs, costps, sts, iv)
778+ end
779+
780+ """
781+ Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
782+ well-formed states or parameters.
783+ - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
784+ - Callable/delay parameters should be parameters of the system
785+
786+ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
787+ parameter of the system.
788+ """
789+ function validate_vars_and_find_ps! (auxvars, auxps, sysvars, iv)
790+ sts = sysvars
791+
792+ for var in auxvars
752793 if ! iscall (var)
753794 occursin (iv, var) && (var ∈ sts ||
754795 throw (ArgumentError (" Time-dependent variable $var is not an unknown of the system." )))
@@ -763,13 +804,17 @@ function process_constraint_system(
763804 arg isa AbstractFloat ||
764805 throw (ArgumentError (" Invalid argument specified for variable $var . The argument of the variable should be either $iv , a parameter, or a value specifying the time that the constraint holds." ))
765806
766- isparameter (arg) && push! (constraintps , arg)
807+ isparameter (arg) && push! (auxps , arg)
767808 else
768809 var ∈ sts &&
769810 @warn " Variable $var has no argument. It will be interpreted as $var ($iv ), and the constraint will apply to the entire interval."
770811 end
771812 end
813+ end
772814
773- ConstraintsSystem (
774- constraints, collect (constraintsts), collect (constraintps); name = consname)
815+ function generate_cost_function (sys:: ODESystem )
816+ costs = get_costs (sys)
817+ coalesce = get_coalesce (sys)
818+ cost_fn = build_function_wrapper ()
819+ return (u, p, t) -> coalesce (cost_fn (u, p, t))
775820end
0 commit comments