@@ -54,7 +54,7 @@ struct ODESystem <: AbstractODESystem
54
54
""" A set of expressions defining the costs of the system for optimal control."""
55
55
costs:: Vector
56
56
""" Takes the cost vector and returns a scalar for optimization."""
57
- coalesce:: Function
57
+ coalesce:: Union{Nothing, Function}
58
58
"""
59
59
Time-derivative matrix. Note: this field will not be defined until
60
60
[`calculate_tgrad`](@ref) is called on the system.
@@ -209,7 +209,7 @@ struct ODESystem <: AbstractODESystem
209
209
parent:: Any
210
210
211
211
function ODESystem (
212
- tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
212
+ tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, costs, coalesce, tgrad,
213
213
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
214
214
torn_matching, initializesystem, initialization_eqs, schedule,
215
215
connector_type, preface, cevents,
@@ -233,7 +233,7 @@ struct ODESystem <: AbstractODESystem
233
233
check_units (u, deqs)
234
234
end
235
235
new (tag, deqs, iv, dvs, ps, tspan, var_to_name,
236
- ctrls, observed, constraints, tgrad, jac,
236
+ ctrls, observed, constraints, costs, coalesce, tgrad, jac,
237
237
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
238
238
initializesystem, initialization_eqs, schedule, connector_type, preface,
239
239
cevents, devents, parameter_dependencies, assertions, metadata,
@@ -247,6 +247,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
247
247
controls = Num[],
248
248
observed = Equation[],
249
249
constraintsystem = nothing ,
250
+ costs = Num[],
251
+ coalesce = nothing ,
250
252
systems = ODESystem[],
251
253
tspan = nothing ,
252
254
name = nothing ,
@@ -327,22 +329,26 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
327
329
cons = get_constraintsystem (sys)
328
330
cons != = nothing && push! (conssystems, cons)
329
331
end
330
- @show conssystems
331
332
@set! constraintsystem. systems = conssystems
332
333
end
334
+ costs = wrap .(costs)
335
+
336
+ if length (costs) > 1 && isnothing (coalesce)
337
+ error (" Must specify a coalesce function for the costs vector." )
338
+ end
333
339
334
340
assertions = Dict {BasicSymbolic, Any} (unwrap (k) => v for (k, v) in assertions)
335
341
336
342
ODESystem (Threads. atomic_add! (SYSTEM_COUNT, UInt (1 )),
337
- deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
343
+ deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, costs, coalesce, tgrad, jac,
338
344
ctrl_jac, Wfact, Wfact_t, name, description, systems,
339
345
defaults, guesses, nothing , initializesystem,
340
346
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
341
347
disc_callbacks, parameter_dependencies, assertions,
342
348
metadata, gui_metadata, is_dde, tstops, checks = checks)
343
349
end
344
350
345
- function ODESystem (eqs, iv; constraints = Equation[], costs = Equation [], kwargs... )
351
+ function ODESystem (eqs, iv; constraints = Equation[], costs = Num [], kwargs... )
346
352
diffvars, allunknowns, ps, eqs = process_equations (eqs, iv)
347
353
348
354
for eq in get (kwargs, :parameter_dependencies , Equation[])
@@ -394,9 +400,10 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Equation[], kwargs
394
400
! in (p, new_ps) && push! (new_ps, p)
395
401
end
396
402
end
403
+ costs = wrap .(costs)
397
404
398
405
return ODESystem (eqs, iv, collect (Iterators. flatten ((diffvars, algevars, consvars))),
399
- collect (new_ps); constraintsystem, kwargs... )
406
+ collect (new_ps); constraintsystem, costs , kwargs... )
400
407
end
401
408
402
409
# NOTE: equality does not check cached Jacobian
@@ -411,7 +418,9 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
411
418
_eq_unordered (get_ps (sys1), get_ps (sys2)) &&
412
419
_eq_unordered (continuous_events (sys1), continuous_events (sys2)) &&
413
420
_eq_unordered (discrete_events (sys1), discrete_events (sys2)) &&
414
- all (s1 == s2 for (s1, s2) in zip (get_systems (sys1), get_systems (sys2)))
421
+ all (s1 == s2 for (s1, s2) in zip (get_systems (sys1), get_systems (sys2))) &&
422
+ isequal (get_constraintsystem (sys1), get_constraintssystem (sys2)) &&
423
+ _eq_unordered (get_costs (sys1), get_costs (sys2))
415
424
end
416
425
417
426
function flatten (sys:: ODESystem , noeqs = false )
@@ -768,14 +777,15 @@ end
768
777
"""
769
778
Process the costs for the constraint system.
770
779
"""
771
- function process_costs (costs:: Vector{Equation} , sts, ps, iv)
780
+ function process_costs (costs:: Vector , sts, ps, iv)
772
781
coststs = OrderedSet ()
773
782
costps = OrderedSet ()
774
783
for cost in costs
775
784
collect_vars! (coststs, costps, cost, iv)
776
785
end
777
786
778
787
validate_vars_and_find_ps! (coststs, costps, sts, iv)
788
+ coststs, costps
779
789
end
780
790
781
791
"""
@@ -813,9 +823,34 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
813
823
end
814
824
end
815
825
816
- function generate_cost_function (sys:: ODESystem )
826
+ """
827
+ Generate a function that takes a solution object and computes the cost function obtained by coalescing the costs vector.
828
+ """
829
+ function generate_cost_function (sys:: ODESystem , kwargs... )
817
830
costs = get_costs (sys)
818
831
coalesce = get_coalesce (sys)
819
- cost_fn = build_function_wrapper ()
820
- return (u, p, t) -> coalesce (cost_fn (u, p, t))
832
+ iv = get_iv (sys)
833
+
834
+ ps = parameters (sys; initial_parameters = false )
835
+ sts = unknowns (sys)
836
+ np = length (ps)
837
+ ns = length (sts)
838
+ stidxmap = Dict ([v => i for (i, v) in enumerate (sts)])
839
+ pidxmap = Dict ([v => i for (i, v) in enumerate (ps)])
840
+
841
+ @variables sol (.. )[1 : ns]
842
+ for st in vars (costs)
843
+ x = operation (st)
844
+ t = only (arguments (st))
845
+ idx = stidxmap[x (iv)]
846
+
847
+ costs = map (c -> Symbolics. fast_substitute (c, Dict (x (t) => sol (t)[idx])), costs)
848
+ end
849
+
850
+ _p = reorder_parameters (sys, ps)
851
+ fs = build_function_wrapper (sys, costs, sol, _p... , t; output_type = Array, kwargs... )
852
+ vc_oop, vc_iip = eval_or_rgf .(fs)
853
+
854
+ cost (sol, p, t) = coalesce (vc_oop (sol, p, t))
855
+ return cost
821
856
end
0 commit comments