Skip to content

Commit 64fb10e

Browse files
committed
init: add cost and coalesce
1 parent 60e202e commit 64fb10e

File tree

1 file changed

+54
-9
lines changed

1 file changed

+54
-9
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
5151
observed::Vector{Equation}
5252
"""System of constraints that must be satisfied by the solution to the system."""
5353
constraintsystem::Union{Nothing, ConstraintsSystem}
54+
"""A set of expressions defining the costs of the system for optimal control."""
55+
costs::Vector
56+
"""Takes the cost vector and returns a scalar for optimization."""
57+
coalesce::Function
5458
"""
5559
Time-derivative matrix. Note: this field will not be defined until
5660
[`calculate_tgrad`](@ref) is called on the system.
@@ -338,7 +342,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
338342
metadata, gui_metadata, is_dde, tstops, checks = checks)
339343
end
340344

341-
function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
345+
function ODESystem(eqs, iv; constraints = Equation[], costs = Equation[], kwargs...)
342346
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
343347

344348
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -384,6 +388,13 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
384388
end
385389
end
386390

391+
if !isempty(costs)
392+
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
393+
for p in costps
394+
!in(p, new_ps) && push!(new_ps, p)
395+
end
396+
end
397+
387398
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
388399
collect(new_ps); constraintsystem, kwargs...)
389400
end
@@ -734,22 +745,52 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
734745
return nothing
735746
end
736747

737-
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
738-
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
739-
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
748+
"""
749+
Build the constraint system for the ODESystem.
750+
"""
740751
function process_constraint_system(
741752
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
742753
isempty(constraints) && return nothing
743754

744755
constraintsts = OrderedSet()
745756
constraintps = OrderedSet()
746-
747757
for cons in constraints
748758
collect_vars!(constraintsts, constraintps, cons, iv)
749759
end
750760

751761
# Validate the states.
752-
for var in constraintsts
762+
validate_vars_and_find_ps!(coststs, costps, sts, iv)
763+
764+
ConstraintsSystem(
765+
constraints, collect(constraintsts), collect(constraintps); name = consname)
766+
end
767+
768+
"""
769+
Process the costs for the constraint system.
770+
"""
771+
function process_costs(costs::Vector{Equation}, sts, ps, iv)
772+
coststs = OrderedSet()
773+
costps = OrderedSet()
774+
for cost in costs
775+
collect_vars!(coststs, costps, cost, iv)
776+
end
777+
778+
validate_vars_and_find_ps!(coststs, costps, sts, iv)
779+
end
780+
781+
"""
782+
Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
783+
well-formed states or parameters.
784+
- Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
785+
- Callable/delay parameters should be parameters of the system
786+
787+
Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
788+
parameter of the system.
789+
"""
790+
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
791+
sts = sysvars
792+
793+
for var in auxvars
753794
if !iscall(var)
754795
occursin(iv, var) && (var sts ||
755796
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
@@ -764,13 +805,17 @@ function process_constraint_system(
764805
arg isa AbstractFloat ||
765806
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
766807

767-
isparameter(arg) && push!(constraintps, arg)
808+
isparameter(arg) && push!(auxps, arg)
768809
else
769810
var sts &&
770811
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
771812
end
772813
end
814+
end
773815

774-
ConstraintsSystem(
775-
constraints, collect(constraintsts), collect(constraintps); name = consname)
816+
function generate_cost_function(sys::ODESystem)
817+
costs = get_costs(sys)
818+
coalesce = get_coalesce(sys)
819+
cost_fn = build_function_wrapper()
820+
return (u, p, t) -> coalesce(cost_fn(u, p, t))
776821
end

0 commit comments

Comments
 (0)