Skip to content

Commit 246464c

Browse files
committed
feat: add cost and coalesce to ODESystem
1 parent 64fb10e commit 246464c

File tree

4 files changed

+78
-18
lines changed

4 files changed

+78
-18
lines changed

src/systems/abstractsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,9 @@ for prop in [:eqs
920920
:tstops
921921
:index_cache
922922
:is_scalar_noise
923-
:isscheduled]
923+
:isscheduled
924+
:costs
925+
:coalesce]
924926
fname_get = Symbol(:get_, prop)
925927
fname_has = Symbol(:has_, prop)
926928
@eval begin

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
924924
exprs = vcat(init_conds, cons)
925925
_p = reorder_parameters(sys, ps)
926926

927-
build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
927+
build_function_wrapper(sys, exprs, sol, _p..., iv; output_type = Array, kwargs...)
928928
end
929929

930930
"""

src/systems/diffeqs/odesystem.jl

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct ODESystem <: AbstractODESystem
5454
"""A set of expressions defining the costs of the system for optimal control."""
5555
costs::Vector
5656
"""Takes the cost vector and returns a scalar for optimization."""
57-
coalesce::Function
57+
coalesce::Union{Nothing, Function}
5858
"""
5959
Time-derivative matrix. Note: this field will not be defined until
6060
[`calculate_tgrad`](@ref) is called on the system.
@@ -209,7 +209,7 @@ struct ODESystem <: AbstractODESystem
209209
parent::Any
210210

211211
function ODESystem(
212-
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
212+
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, costs, coalesce, tgrad,
213213
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
214214
torn_matching, initializesystem, initialization_eqs, schedule,
215215
connector_type, preface, cevents,
@@ -233,7 +233,7 @@ struct ODESystem <: AbstractODESystem
233233
check_units(u, deqs)
234234
end
235235
new(tag, deqs, iv, dvs, ps, tspan, var_to_name,
236-
ctrls, observed, constraints, tgrad, jac,
236+
ctrls, observed, constraints, costs, coalesce, tgrad, jac,
237237
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
238238
initializesystem, initialization_eqs, schedule, connector_type, preface,
239239
cevents, devents, parameter_dependencies, assertions, metadata,
@@ -247,6 +247,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
247247
controls = Num[],
248248
observed = Equation[],
249249
constraintsystem = nothing,
250+
costs = Num[],
251+
coalesce = nothing,
250252
systems = ODESystem[],
251253
tspan = nothing,
252254
name = nothing,
@@ -327,22 +329,26 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
327329
cons = get_constraintsystem(sys)
328330
cons !== nothing && push!(conssystems, cons)
329331
end
330-
@show conssystems
331332
@set! constraintsystem.systems = conssystems
332333
end
334+
costs = wrap.(costs)
335+
336+
if length(costs) > 1 && isnothing(coalesce)
337+
error("Must specify a coalesce function for the costs vector.")
338+
end
333339

334340
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
335341

336342
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
337-
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
343+
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, costs, coalesce, tgrad, jac,
338344
ctrl_jac, Wfact, Wfact_t, name, description, systems,
339345
defaults, guesses, nothing, initializesystem,
340346
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
341347
disc_callbacks, parameter_dependencies, assertions,
342348
metadata, gui_metadata, is_dde, tstops, checks = checks)
343349
end
344350

345-
function ODESystem(eqs, iv; constraints = Equation[], costs = Equation[], kwargs...)
351+
function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
346352
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
347353

348354
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -394,9 +400,10 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Equation[], kwargs
394400
!in(p, new_ps) && push!(new_ps, p)
395401
end
396402
end
403+
costs = wrap.(costs)
397404

398405
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
399-
collect(new_ps); constraintsystem, kwargs...)
406+
collect(new_ps); constraintsystem, costs, kwargs...)
400407
end
401408

402409
# NOTE: equality does not check cached Jacobian
@@ -411,7 +418,9 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
411418
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
412419
_eq_unordered(continuous_events(sys1), continuous_events(sys2)) &&
413420
_eq_unordered(discrete_events(sys1), discrete_events(sys2)) &&
414-
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
421+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) &&
422+
isequal(get_constraintsystem(sys1), get_constraintssystem(sys2)) &&
423+
_eq_unordered(get_costs(sys1), get_costs(sys2))
415424
end
416425

417426
function flatten(sys::ODESystem, noeqs = false)
@@ -768,14 +777,15 @@ end
768777
"""
769778
Process the costs for the constraint system.
770779
"""
771-
function process_costs(costs::Vector{Equation}, sts, ps, iv)
780+
function process_costs(costs::Vector, sts, ps, iv)
772781
coststs = OrderedSet()
773782
costps = OrderedSet()
774783
for cost in costs
775784
collect_vars!(coststs, costps, cost, iv)
776785
end
777786

778787
validate_vars_and_find_ps!(coststs, costps, sts, iv)
788+
coststs, costps
779789
end
780790

781791
"""
@@ -813,9 +823,34 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
813823
end
814824
end
815825

816-
function generate_cost_function(sys::ODESystem)
826+
"""
827+
Generate a function that takes a solution object and computes the cost function obtained by coalescing the costs vector.
828+
"""
829+
function generate_cost_function(sys::ODESystem, kwargs...)
817830
costs = get_costs(sys)
818831
coalesce = get_coalesce(sys)
819-
cost_fn = build_function_wrapper()
820-
return (u, p, t) -> coalesce(cost_fn(u, p, t))
832+
iv = get_iv(sys)
833+
834+
ps = parameters(sys; initial_parameters = false)
835+
sts = unknowns(sys)
836+
np = length(ps)
837+
ns = length(sts)
838+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
839+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
840+
841+
@variables sol(..)[1:ns]
842+
for st in vars(costs)
843+
x = operation(st)
844+
t = only(arguments(st))
845+
idx = stidxmap[x(iv)]
846+
847+
costs = map(c -> Symbolics.fast_substitute(c, Dict(x(t) => sol(t)[idx])), costs)
848+
end
849+
850+
_p = reorder_parameters(sys, ps)
851+
fs = build_function_wrapper(sys, costs, sol, _p..., t; output_type = Array, kwargs...)
852+
vc_oop, vc_iip = eval_or_rgf.(fs)
853+
854+
cost(sol, p, t) = coalesce(vc_oop(sol, p, t))
855+
return cost
821856
end

test/bvproblem.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
1111
solvers = [MIRK4]
1212
daesolvers = [Ascher2, Ascher4, Ascher6]
1313

14-
let
14+
@testset "Lotka-Volterra" begin
1515
@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
1616
@variables x(t)=1.0 y(t)=2.0
1717

@@ -47,7 +47,7 @@ let
4747
end
4848

4949
### Testing on pendulum
50-
let
50+
@testset "Pendulum" begin
5151
@parameters g=9.81 L=1.0
5252
@variables θ(t)=π / 2 θ_t(t)
5353

@@ -86,7 +86,7 @@ end
8686
##################################################################
8787

8888
# Test generation of boundary condition function using `generate_function_bc`. Compare solutions to manually written boundary conditions
89-
let
89+
@testset "Boundary Condition Compilation" begin
9090
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
9191
@variables x(..) y(..)
9292
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
@@ -168,7 +168,7 @@ function test_solvers(
168168
end
169169

170170
# Simple ODESystem with BVP constraints.
171-
let
171+
@testset "ODE with constraints" begin
172172
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
173173
@variables x(..) y(..)
174174

@@ -274,3 +274,26 @@ end
274274
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
275275
# test_solvers(daesolvers, bvp, u0map, constr, get_alg_eqs(pend))
276276
# end
277+
278+
@testset "Cost function compilation" begin
279+
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
280+
@variables x(..) y(..)
281+
282+
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
283+
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
284+
285+
tspan = (0.0, 1.0)
286+
u0map = [x(t) => 4.0, y(t) => 2.0]
287+
parammap ==> 7.5, β => 4, γ => 8.0, δ => 5.0]
288+
costs = [x(0.6), x(0.3)^2]
289+
consolidate(u) = (u[1] + 3)^2 + u[2]
290+
@mtkbuild lksys = ODESystem(eqs, t; costs, coalesce = consolidate)
291+
@test_throws ErrorException @mtkbuild lksys2 = ODESystem(eqs, t; costs)
292+
293+
prob = ODEProblem(lksys, u0map, tspan, parammap)
294+
sol = solve(prob, Tsit5())
295+
costfn = ModelingToolkit.generate_cost_function(lksys)
296+
p = prob.p
297+
t = tspan[2]
298+
@test costfn(sol, p, t) (sol(0.6)[1] + 3)^2 + sol(0.3)[1]^2
299+
end

0 commit comments

Comments
 (0)