Skip to content

Commit 0abd825

Browse files
committed
init: add cost and coalesce
1 parent d8f0e2a commit 0abd825

File tree

1 file changed

+54
-9
lines changed

1 file changed

+54
-9
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
5151
observed::Vector{Equation}
5252
"""System of constraints that must be satisfied by the solution to the system."""
5353
constraintsystem::Union{Nothing, ConstraintsSystem}
54+
"""A set of expressions defining the costs of the system for optimal control."""
55+
costs::Vector
56+
"""Takes the cost vector and returns a scalar for optimization."""
57+
coalesce::Function
5458
"""
5559
Time-derivative matrix. Note: this field will not be defined until
5660
[`calculate_tgrad`](@ref) is called on the system.
@@ -338,7 +342,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
338342
metadata, gui_metadata, is_dde, tstops, checks = checks)
339343
end
340344

341-
function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
345+
function ODESystem(eqs, iv; constraints = Equation[], costs = Equation[], kwargs...)
342346
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
343347

344348
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -384,6 +388,13 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
384388
end
385389
end
386390

391+
if !isempty(costs)
392+
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
393+
for p in costps
394+
!in(p, new_ps) && push!(new_ps, p)
395+
end
396+
end
397+
387398
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
388399
collect(new_ps); constraintsystem, kwargs...)
389400
end
@@ -733,22 +744,52 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
733744
return nothing
734745
end
735746

736-
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
737-
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
738-
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
747+
"""
748+
Build the constraint system for the ODESystem.
749+
"""
739750
function process_constraint_system(
740751
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
741752
isempty(constraints) && return nothing
742753

743754
constraintsts = OrderedSet()
744755
constraintps = OrderedSet()
745-
746756
for cons in constraints
747757
collect_vars!(constraintsts, constraintps, cons, iv)
748758
end
749759

750760
# Validate the states.
751-
for var in constraintsts
761+
validate_vars_and_find_ps!(coststs, costps, sts, iv)
762+
763+
ConstraintsSystem(
764+
constraints, collect(constraintsts), collect(constraintps); name = consname)
765+
end
766+
767+
"""
768+
Process the costs for the constraint system.
769+
"""
770+
function process_costs(costs::Vector{Equation}, sts, ps, iv)
771+
coststs = OrderedSet()
772+
costps = OrderedSet()
773+
for cost in costs
774+
collect_vars!(coststs, costps, cost, iv)
775+
end
776+
777+
validate_vars_and_find_ps!(coststs, costps, sts, iv)
778+
end
779+
780+
"""
781+
Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
782+
well-formed states or parameters.
783+
- Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
784+
- Callable/delay parameters should be parameters of the system
785+
786+
Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
787+
parameter of the system.
788+
"""
789+
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
790+
sts = sysvars
791+
792+
for var in auxvars
752793
if !iscall(var)
753794
occursin(iv, var) && (var sts ||
754795
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
@@ -763,13 +804,17 @@ function process_constraint_system(
763804
arg isa AbstractFloat ||
764805
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
765806

766-
isparameter(arg) && push!(constraintps, arg)
807+
isparameter(arg) && push!(auxps, arg)
767808
else
768809
var sts &&
769810
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
770811
end
771812
end
813+
end
772814

773-
ConstraintsSystem(
774-
constraints, collect(constraintsts), collect(constraintps); name = consname)
815+
function generate_cost_function(sys::ODESystem)
816+
costs = get_costs(sys)
817+
coalesce = get_coalesce(sys)
818+
cost_fn = build_function_wrapper()
819+
return (u, p, t) -> coalesce(cost_fn(u, p, t))
775820
end

0 commit comments

Comments
 (0)