@@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
51
51
observed:: Vector{Equation}
52
52
""" System of constraints that must be satisfied by the solution to the system."""
53
53
constraintsystem:: Union{Nothing, ConstraintsSystem}
54
+ """ A set of expressions defining the costs of the system for optimal control."""
55
+ costs:: Vector
56
+ """ Takes the cost vector and returns a scalar for optimization."""
57
+ coalesce:: Function
54
58
"""
55
59
Time-derivative matrix. Note: this field will not be defined until
56
60
[`calculate_tgrad`](@ref) is called on the system.
@@ -338,7 +342,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
338
342
metadata, gui_metadata, is_dde, tstops, checks = checks)
339
343
end
340
344
341
- function ODESystem (eqs, iv; constraints = Equation[], kwargs... )
345
+ function ODESystem (eqs, iv; constraints = Equation[], costs = Equation[], kwargs... )
342
346
diffvars, allunknowns, ps, eqs = process_equations (eqs, iv)
343
347
344
348
for eq in get (kwargs, :parameter_dependencies , Equation[])
@@ -384,6 +388,13 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
384
388
end
385
389
end
386
390
391
+ if ! isempty (costs)
392
+ coststs, costps = process_costs (costs, allunknowns, new_ps, iv)
393
+ for p in costps
394
+ ! in (p, new_ps) && push! (new_ps, p)
395
+ end
396
+ end
397
+
387
398
return ODESystem (eqs, iv, collect (Iterators. flatten ((diffvars, algevars, consvars))),
388
399
collect (new_ps); constraintsystem, kwargs... )
389
400
end
@@ -734,22 +745,52 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
734
745
return nothing
735
746
end
736
747
737
- # Validate that all the variables in the BVP constraints are well-formed states or parameters.
738
- # - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
739
- # - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
748
+ """
749
+ Build the constraint system for the ODESystem.
750
+ """
740
751
function process_constraint_system (
741
752
constraints:: Vector{Equation} , sts, ps, iv; consname = :cons )
742
753
isempty (constraints) && return nothing
743
754
744
755
constraintsts = OrderedSet ()
745
756
constraintps = OrderedSet ()
746
-
747
757
for cons in constraints
748
758
collect_vars! (constraintsts, constraintps, cons, iv)
749
759
end
750
760
751
761
# Validate the states.
752
- for var in constraintsts
762
+ validate_vars_and_find_ps! (coststs, costps, sts, iv)
763
+
764
+ ConstraintsSystem (
765
+ constraints, collect (constraintsts), collect (constraintps); name = consname)
766
+ end
767
+
768
+ """
769
+ Process the costs for the constraint system.
770
+ """
771
+ function process_costs (costs:: Vector{Equation} , sts, ps, iv)
772
+ coststs = OrderedSet ()
773
+ costps = OrderedSet ()
774
+ for cost in costs
775
+ collect_vars! (coststs, costps, cost, iv)
776
+ end
777
+
778
+ validate_vars_and_find_ps! (coststs, costps, sts, iv)
779
+ end
780
+
781
+ """
782
+ Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
783
+ well-formed states or parameters.
784
+ - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
785
+ - Callable/delay parameters should be parameters of the system
786
+
787
+ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
788
+ parameter of the system.
789
+ """
790
+ function validate_vars_and_find_ps! (auxvars, auxps, sysvars, iv)
791
+ sts = sysvars
792
+
793
+ for var in auxvars
753
794
if ! iscall (var)
754
795
occursin (iv, var) && (var ∈ sts ||
755
796
throw (ArgumentError (" Time-dependent variable $var is not an unknown of the system." )))
@@ -764,13 +805,17 @@ function process_constraint_system(
764
805
arg isa AbstractFloat ||
765
806
throw (ArgumentError (" Invalid argument specified for variable $var . The argument of the variable should be either $iv , a parameter, or a value specifying the time that the constraint holds." ))
766
807
767
- isparameter (arg) && push! (constraintps , arg)
808
+ isparameter (arg) && push! (auxps , arg)
768
809
else
769
810
var ∈ sts &&
770
811
@warn " Variable $var has no argument. It will be interpreted as $var ($iv ), and the constraint will apply to the entire interval."
771
812
end
772
813
end
814
+ end
773
815
774
- ConstraintsSystem (
775
- constraints, collect (constraintsts), collect (constraintps); name = consname)
816
+ function generate_cost_function (sys:: ODESystem )
817
+ costs = get_costs (sys)
818
+ coalesce = get_coalesce (sys)
819
+ cost_fn = build_function_wrapper ()
820
+ return (u, p, t) -> coalesce (cost_fn (u, p, t))
776
821
end
0 commit comments