@@ -54,7 +54,7 @@ struct ODESystem <: AbstractODESystem
5454 """ A set of expressions defining the costs of the system for optimal control."""
5555 costs:: Vector
5656 """ Takes the cost vector and returns a scalar for optimization."""
57- coalesce:: Function
57+ coalesce:: Union{Nothing, Function}
5858 """
5959 Time-derivative matrix. Note: this field will not be defined until
6060 [`calculate_tgrad`](@ref) is called on the system.
@@ -209,7 +209,7 @@ struct ODESystem <: AbstractODESystem
209209 parent:: Any
210210
211211 function ODESystem (
212- tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
212+ tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, costs, coalesce, tgrad,
213213 jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
214214 torn_matching, initializesystem, initialization_eqs, schedule,
215215 connector_type, preface, cevents,
@@ -233,7 +233,7 @@ struct ODESystem <: AbstractODESystem
233233 check_units (u, deqs)
234234 end
235235 new (tag, deqs, iv, dvs, ps, tspan, var_to_name,
236- ctrls, observed, constraints, tgrad, jac,
236+ ctrls, observed, constraints, costs, coalesce, tgrad, jac,
237237 ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
238238 initializesystem, initialization_eqs, schedule, connector_type, preface,
239239 cevents, devents, parameter_dependencies, assertions, metadata,
@@ -247,6 +247,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
247247 controls = Num[],
248248 observed = Equation[],
249249 constraintsystem = nothing ,
250+ costs = Num[],
251+ coalesce = nothing ,
250252 systems = ODESystem[],
251253 tspan = nothing ,
252254 name = nothing ,
@@ -327,22 +329,26 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
327329 cons = get_constraintsystem (sys)
328330 cons != = nothing && push! (conssystems, cons)
329331 end
330- @show conssystems
331332 @set! constraintsystem. systems = conssystems
332333 end
334+ costs = wrap .(costs)
335+
336+ if length (costs) > 1 && isnothing (coalesce)
337+ error (" Must specify a coalesce function for the costs vector." )
338+ end
333339
334340 assertions = Dict {BasicSymbolic, Any} (unwrap (k) => v for (k, v) in assertions)
335341
336342 ODESystem (Threads. atomic_add! (SYSTEM_COUNT, UInt (1 )),
337- deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
343+ deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, costs, coalesce, tgrad, jac,
338344 ctrl_jac, Wfact, Wfact_t, name, description, systems,
339345 defaults, guesses, nothing , initializesystem,
340346 initialization_eqs, schedule, connector_type, preface, cont_callbacks,
341347 disc_callbacks, parameter_dependencies, assertions,
342348 metadata, gui_metadata, is_dde, tstops, checks = checks)
343349end
344350
345- function ODESystem (eqs, iv; constraints = Equation[], costs = Equation [], kwargs... )
351+ function ODESystem (eqs, iv; constraints = Equation[], costs = Num [], kwargs... )
346352 diffvars, allunknowns, ps, eqs = process_equations (eqs, iv)
347353
348354 for eq in get (kwargs, :parameter_dependencies , Equation[])
@@ -394,9 +400,10 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Equation[], kwargs
394400 ! in (p, new_ps) && push! (new_ps, p)
395401 end
396402 end
403+ costs = wrap .(costs)
397404
398405 return ODESystem (eqs, iv, collect (Iterators. flatten ((diffvars, algevars, consvars))),
399- collect (new_ps); constraintsystem, kwargs... )
406+ collect (new_ps); constraintsystem, costs , kwargs... )
400407end
401408
402409# NOTE: equality does not check cached Jacobian
@@ -411,7 +418,9 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
411418 _eq_unordered (get_ps (sys1), get_ps (sys2)) &&
412419 _eq_unordered (continuous_events (sys1), continuous_events (sys2)) &&
413420 _eq_unordered (discrete_events (sys1), discrete_events (sys2)) &&
414- all (s1 == s2 for (s1, s2) in zip (get_systems (sys1), get_systems (sys2)))
421+ all (s1 == s2 for (s1, s2) in zip (get_systems (sys1), get_systems (sys2))) &&
422+ isequal (get_constraintsystem (sys1), get_constraintssystem (sys2)) &&
423+ _eq_unordered (get_costs (sys1), get_costs (sys2))
415424end
416425
417426function flatten (sys:: ODESystem , noeqs = false )
@@ -767,14 +776,15 @@ end
767776"""
768777Process the costs for the constraint system.
769778"""
770- function process_costs (costs:: Vector{Equation} , sts, ps, iv)
779+ function process_costs (costs:: Vector , sts, ps, iv)
771780 coststs = OrderedSet ()
772781 costps = OrderedSet ()
773782 for cost in costs
774783 collect_vars! (coststs, costps, cost, iv)
775784 end
776785
777786 validate_vars_and_find_ps! (coststs, costps, sts, iv)
787+ coststs, costps
778788end
779789
780790"""
@@ -812,9 +822,34 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
812822 end
813823end
814824
815- function generate_cost_function (sys:: ODESystem )
825+ """
826+ Generate a function that takes a solution object and computes the cost function obtained by coalescing the costs vector.
827+ """
828+ function generate_cost_function (sys:: ODESystem , kwargs... )
816829 costs = get_costs (sys)
817830 coalesce = get_coalesce (sys)
818- cost_fn = build_function_wrapper ()
819- return (u, p, t) -> coalesce (cost_fn (u, p, t))
831+ iv = get_iv (sys)
832+
833+ ps = parameters (sys; initial_parameters = false )
834+ sts = unknowns (sys)
835+ np = length (ps)
836+ ns = length (sts)
837+ stidxmap = Dict ([v => i for (i, v) in enumerate (sts)])
838+ pidxmap = Dict ([v => i for (i, v) in enumerate (ps)])
839+
840+ @variables sol (.. )[1 : ns]
841+ for st in vars (costs)
842+ x = operation (st)
843+ t = only (arguments (st))
844+ idx = stidxmap[x (iv)]
845+
846+ costs = map (c -> Symbolics. fast_substitute (c, Dict (x (t) => sol (t)[idx])), costs)
847+ end
848+
849+ _p = reorder_parameters (sys, ps)
850+ fs = build_function_wrapper (sys, costs, sol, _p... , t; output_type = Array, kwargs... )
851+ vc_oop, vc_iip = eval_or_rgf .(fs)
852+
853+ cost (sol, p, t) = coalesce (vc_oop (sol, p, t))
854+ return cost
820855end
0 commit comments