Skip to content

Commit 2ba3da7

Browse files
committed
fix: fix tests, add parameter test
1 parent d69dae7 commit 2ba3da7

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ for prop in [:eqs
922922
:is_scalar_noise
923923
:isscheduled
924924
:costs
925-
:coalesce]
925+
:consolidate]
926926
fname_get = Symbol(:get_, prop)
927927
fname_has = Symbol(:has_, prop)
928928
@eval begin

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct ODESystem <: AbstractODESystem
5454
"""A set of expressions defining the costs of the system for optimal control."""
5555
costs::Vector
5656
"""Takes the cost vector and returns a scalar for optimization."""
57-
coalesce::Union{Nothing, Function}
57+
consolidate::Union{Nothing, Function}
5858
"""
5959
Time-derivative matrix. Note: this field will not be defined until
6060
[`calculate_tgrad`](@ref) is called on the system.
@@ -210,7 +210,7 @@ struct ODESystem <: AbstractODESystem
210210

211211
function ODESystem(
212212
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls,
213-
observed, constraints, costs, coalesce, tgrad,
213+
observed, constraints, costs, consolidate, tgrad,
214214
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
215215
torn_matching, initializesystem, initialization_eqs, schedule,
216216
connector_type, preface, cevents,
@@ -234,7 +234,7 @@ struct ODESystem <: AbstractODESystem
234234
check_units(u, deqs)
235235
end
236236
new(tag, deqs, iv, dvs, ps, tspan, var_to_name,
237-
ctrls, observed, constraints, costs, coalesce, tgrad, jac,
237+
ctrls, observed, constraints, costs, consolidate, tgrad, jac,
238238
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
239239
initializesystem, initialization_eqs, schedule, connector_type, preface,
240240
cevents, devents, parameter_dependencies, assertions, metadata,
@@ -249,7 +249,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
249249
observed = Equation[],
250250
constraintsystem = nothing,
251251
costs = Num[],
252-
coalesce = nothing,
252+
consolidate = nothing,
253253
systems = ODESystem[],
254254
tspan = nothing,
255255
name = nothing,
@@ -334,15 +334,15 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
334334
end
335335
costs = wrap.(costs)
336336

337-
if length(costs) > 1 && isnothing(coalesce)
338-
error("Must specify a coalesce function for the costs vector.")
337+
if length(costs) > 1 && isnothing(consolidate)
338+
error("Must specify a consolidation function for the costs vector.")
339339
end
340340

341341
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
342342

343343
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
344344
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed,
345-
constraintsystem, costs, coalesce, tgrad, jac,
345+
constraintsystem, costs, consolidate, tgrad, jac,
346346
ctrl_jac, Wfact, Wfact_t, name, description, systems,
347347
defaults, guesses, nothing, initializesystem,
348348
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
@@ -421,7 +421,7 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
421421
_eq_unordered(continuous_events(sys1), continuous_events(sys2)) &&
422422
_eq_unordered(discrete_events(sys1), discrete_events(sys2)) &&
423423
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) &&
424-
isequal(get_constraintsystem(sys1), get_constraintssystem(sys2)) &&
424+
isequal(get_constraintsystem(sys1), get_constraintsystem(sys2)) &&
425425
_eq_unordered(get_costs(sys1), get_costs(sys2))
426426
end
427427

@@ -770,7 +770,7 @@ function process_constraint_system(
770770
end
771771

772772
# Validate the states.
773-
validate_vars_and_find_ps!(coststs, costps, sts, iv)
773+
validate_vars_and_find_ps!(constraintsts, constraintps, sts, iv)
774774

775775
ConstraintsSystem(
776776
constraints, collect(constraintsts), collect(constraintps); name = consname)
@@ -830,7 +830,7 @@ Generate a function that takes a solution object and computes the cost function
830830
"""
831831
function generate_cost_function(sys::ODESystem, kwargs...)
832832
costs = get_costs(sys)
833-
coalesce = get_coalesce(sys)
833+
consolidate = get_consolidate(sys)
834834
iv = get_iv(sys)
835835

836836
ps = parameters(sys; initial_parameters = false)
@@ -853,6 +853,6 @@ function generate_cost_function(sys::ODESystem, kwargs...)
853853
fs = build_function_wrapper(sys, costs, sol, _p..., t; output_type = Array, kwargs...)
854854
vc_oop, vc_iip = eval_or_rgf.(fs)
855855

856-
cost(sol, p, t) = coalesce(vc_oop(sol, p, t))
856+
cost(sol, p, t) = consolidate(vc_oop(sol, p, t))
857857
return cost
858858
end

test/optimal_control.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions
2-
32
using OrdinaryDiffEq
43
using BoundaryValueDiffEqMIRK, BoundaryValueDiffEqAscher
54
using BenchmarkTools
@@ -278,6 +277,7 @@ end
278277
@testset "Cost function compilation" begin
279278
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
280279
@variables x(..) y(..)
280+
t = ModelingToolkit.t_nounits
281281

282282
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
283283
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
@@ -287,13 +287,24 @@ end
287287
parammap ==> 7.5, β => 4, γ => 8.0, δ => 5.0]
288288
costs = [x(0.6), x(0.3)^2]
289289
consolidate(u) = (u[1] + 3)^2 + u[2]
290-
@mtkbuild lksys = ODESystem(eqs, t; costs, coalesce = consolidate)
290+
@mtkbuild lksys = ODESystem(eqs, t; costs, consolidate)
291291
@test_throws ErrorException @mtkbuild lksys2 = ODESystem(eqs, t; costs)
292292

293293
prob = ODEProblem(lksys, u0map, tspan, parammap)
294294
sol = solve(prob, Tsit5())
295295
costfn = ModelingToolkit.generate_cost_function(lksys)
296-
p = prob.p
297-
t = tspan[2]
298-
@test costfn(sol, p, t) (sol(0.6)[1] + 3)^2 + sol(0.3)[1]^2
296+
_t = tspan[2]
297+
@test costfn(sol, prob.p, _t) (sol(0.6)[1] + 3)^2 + sol(0.3)[1]^2
298+
299+
### With a parameter
300+
@parameters t_c
301+
costs = [y(t_c) + x(0.0), x(0.4)^2]
302+
consolidate(u) = log(u[1]) - u[2]
303+
@mtkbuild lksys = ODESystem(eqs, t; costs, consolidate)
304+
@test t_c Set(parameters(lksys))
305+
push!(parammap, t_c => 0.56)
306+
prob = ODEProblem(lksys, u0map, tspan, parammap)
307+
sol = solve(prob, Tsit5())
308+
costfn = ModelingToolkit.generate_cost_function(lksys)
309+
@test costfn(sol, prob.p, _t) log(sol(0.56)[2] + sol(0.)[1]) - sol(0.4)[1]^2
299310
end

0 commit comments

Comments
 (0)