Skip to content

Commit 6c5935a

Browse files
Merge pull request #562 from SciML/optimal_control
Add ControlSystem for Nonlinear Optimal Control
2 parents 95f14a3 + 432da8c commit 6c5935a

File tree

7 files changed

+214
-3
lines changed

7 files changed

+214
-3
lines changed

docs/src/systems/ControlSystem.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# ControlSystem
2+
3+
## System Constructors
4+
5+
```@docs
6+
ControlSystem
7+
```
8+
9+
## Composition and Accessor Functions
10+
11+
- `sys.eqs` or `equations(sys)`: The equations that define the system.
12+
- `sys.states` or `states(sys)`: The set of states in the system.
13+
- `sys.parameters` or `parameters(sys)`: The parameters of the system.
14+
- `sys.controls` or `controls(sys)`: The control variables of the system
15+
16+
## Transformations
17+
18+
```@docs
19+
runge_kutta_discretize
20+
```

src/ModelingToolkit.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ include("systems/nonlinear/nonlinearsystem.jl")
108108

109109
include("systems/optimization/optimizationsystem.jl")
110110

111+
include("systems/control/controlsystem.jl")
112+
111113
include("systems/pde/pdesystem.jl")
112114

113115
include("systems/reaction/reactionsystem.jl")
@@ -125,14 +127,16 @@ export OptimizationProblem, OptimizationProblemExpr
125127
export SteadyStateProblem, SteadyStateProblemExpr
126128
export JumpProblem, DiscreteProblem
127129
export NonlinearSystem, OptimizationSystem
130+
export ControlSystem
128131
export ode_order_lowering
132+
export runge_kutta_discretize
129133
export PDESystem
130134
export Reaction, ReactionSystem, ismassaction, oderatelaw, jumpratelaw
131135
export Differential, expand_derivatives, @derivatives
132136
export IntervalDomain, ProductDomain, , CircleDomain
133137
export Equation, ConstrainedEquation
134138
export Operation, Expression, Variable
135-
export independent_variable, states, parameters, equations, pins, observed
139+
export independent_variable, states, controls, parameters, equations, pins, observed
136140

137141
export calculate_jacobian, generate_jacobian, generate_function
138142
export calculate_tgrad, generate_tgrad

src/systems/abstractsystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ end
171171

172172
namespace_equations(sys::AbstractSystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
173173

174-
175174
function namespace_equation(eq::Equation,name,ivname)
176175
_lhs = namespace_operation(eq.lhs,name,ivname)
177176
_rhs = namespace_operation(eq.rhs,name,ivname)

src/systems/control/controlsystem.jl

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
abstract type AbstractControlSystem <: AbstractSystem end
2+
3+
function namespace_controls(sys::AbstractSystem)
4+
[rename(x,renamespace(sys.name,x.name)) for x in controls(sys)]
5+
end
6+
7+
function controls(sys::AbstractControlSystem,args...)
8+
name = last(args)
9+
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
10+
newname = renamespace(extra_names,name)
11+
rename(x,renamespace(sys.name,newname))(sys.iv())
12+
end
13+
14+
function controls(sys::AbstractControlSystem,name::Symbol)
15+
x = sys.controls[findfirst(x->x.name==name,sys.ps)]
16+
rename(x,renamespace(sys.name,x.name))()
17+
end
18+
19+
controls(sys::AbstractControlSystem) = isempty(sys.systems) ? sys.controls : [sys.controls;reduce(vcat,namespace_controls.(sys.systems))]
20+
21+
"""
22+
$(TYPEDEF)
23+
24+
A system describing an optimal control problem. This contains a loss function
25+
and ordinary differential equations with control variables that describe the
26+
dynamics.
27+
28+
# Fields
29+
$(FIELDS)
30+
31+
# Example
32+
33+
```
34+
using ModelingToolkit
35+
36+
@variables t x(t) v(t) u(t)
37+
@derivatives D'~t
38+
39+
loss = (4-x)^2 + 2v^2 + u^2
40+
eqs = [
41+
D(x) ~ v
42+
D(v) ~ u^3
43+
]
44+
45+
sys = ControlSystem(loss,eqs,t,[x,v],[u],[])
46+
```
47+
"""
48+
struct ControlSystem <: AbstractControlSystem
49+
"""The Loss function"""
50+
loss::Operation
51+
"""The ODEs defining the system."""
52+
eqs::Vector{Equation}
53+
"""Independent variable."""
54+
iv::Variable
55+
"""Dependent (state) variables."""
56+
states::Vector{Variable}
57+
"""Control variables."""
58+
controls::Vector{Variable}
59+
"""Parameter variables."""
60+
ps::Vector{Variable}
61+
pins::Vector{Variable}
62+
observed::Vector{Equation}
63+
"""
64+
Name: the name of the system
65+
"""
66+
name::Symbol
67+
"""
68+
systems: The internal systems
69+
"""
70+
systems::Vector{ControlSystem}
71+
end
72+
73+
function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps;
74+
pins = Variable[],
75+
observed = Operation[],
76+
systems = ODESystem[],
77+
name=gensym(:ControlSystem))
78+
iv′ = convert(Variable,iv)
79+
dvs′ = convert.(Variable,dvs)
80+
controls′ = convert.(Variable,controls)
81+
ps′ = convert.(Variable,ps)
82+
ControlSystem(loss, deqs, iv′, dvs′, controls′, ps′, pins, observed, name, systems)
83+
end
84+
85+
struct ControlToExpr
86+
sys::AbstractControlSystem
87+
states::Vector{Variable}
88+
controls::Vector{Variable}
89+
end
90+
ControlToExpr(@nospecialize(sys)) = ControlToExpr(sys,states(sys),controls(sys))
91+
function (f::ControlToExpr)(O::Operation)
92+
if isa(O.op, Variable)
93+
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
94+
O.op f.states && return O.op.name # dependent variables
95+
O.op f.controls && return O.op.name # control variables
96+
isempty(O.args) && return O.op.name # 0-ary parameters
97+
return build_expr(:call, Any[O.op.name; f.(O.args)])
98+
end
99+
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
100+
end
101+
(f::ControlToExpr)(x) = convert(Expr, x)
102+
103+
function constructRadauIIA5(T::Type = Float64)
104+
sq6 = sqrt(6)
105+
A = [11//45-7sq6/360 37//225-169sq6/1800 -2//225+sq6/75
106+
37//225+169sq6/1800 11//45+7sq6/360 -2//225-sq6/75
107+
4//9-sq6/36 4//9+sq6/36 1//9]
108+
c = [2//5-sq6/10;2/5+sq6/10;1]
109+
α = [4//9-sq6/36;4//9+sq6/36;1//9]
110+
A = map(T,A)
111+
α = map(T,α)
112+
c = map(T,c)
113+
return(DiffEqBase.ImplicitRKTableau(A,c,α,5))
114+
end
115+
116+
117+
"""
118+
```julia
119+
runge_kutta_discretize(sys::ControlSystem,dt,tspan;
120+
tab = ModelingToolkit.constructRadauIIA5())
121+
```
122+
123+
Transforms a nonlinear optimal control problem into a constrained
124+
`OptimizationProblem` according to a Runge-Kutta tableau that describes
125+
a collocation method. Requires a fixed `dt` over a given `timespan`.
126+
Defaults to using the 5th order RadauIIA tableau, and altnerative tableaus
127+
can be specified using the SciML tableau style.
128+
"""
129+
function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
130+
tab = ModelingToolkit.constructRadauIIA5())
131+
n = length(tspan[1]:dt:tspan[2]) - 1
132+
m = length(tab.α)
133+
134+
f = eval(build_function([x.rhs for x in equations(sys)],sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))[1])
135+
L = eval(build_function(sys.loss,sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys)))
136+
137+
# Expand out all of the variables in time and by stages
138+
timed_vars = [[Variable(x.name,i)(sys.iv()) for i in 1:n+1] for x in states(sys)]
139+
k_vars = [[Variable(Symbol(:ᵏ,x.name),i,j)(sys.iv()) for i in 1:m, j in 1:n] for x in states(sys)]
140+
states_timeseries = [getindex.(timed_vars,j) for j in 1:n+1]
141+
k_timeseries = [[getindex.(k_vars,i,j) for i in 1:m] for j in 1:n]
142+
control_timeseries = [[[Variable(x.name,i,j)(sys.iv()) for x in controls(sys)] for i in 1:m] for j in 1:n]
143+
ps = parameters(sys)
144+
iv = sys.iv()
145+
146+
# Calculate all of the update and stage equations
147+
mult = [tab.A * k_timeseries[i] for i in 1:n]
148+
tmps = [[states_timeseries[i] .+ mult[i][j] for j in 1:m] for i in 1:n]
149+
150+
bs = [states_timeseries[i] .+ dt .* sum(tab.α .* k_timeseries[i],dims=1)[1] for i in 1:n]
151+
updates = reduce(vcat,[states_timeseries[i+1] .~ bs[i] for i in 1:n])
152+
153+
df = [[dt .* Base.invokelatest(f,tmps[j][i],control_timeseries[j][i],ps,iv) for i in 1:m] for j in 1:n]
154+
stages = reduce(vcat,[k_timeseries[i][j] .~ df[i][j] for i in 1:n for j in 1:m])
155+
156+
# Enforce equalities in the controls
157+
control_equality = reduce(vcat,[control_timeseries[i][end] .~ control_timeseries[i+1][1] for i in 1:n-1])
158+
159+
# Create the loss function
160+
losses = [Base.invokelatest(L,states_timeseries[i],control_timeseries[i][1],(ps,),(iv,)) for i in 1:n]
161+
losses = vcat(losses,[Base.invokelatest(L,states_timeseries[n+1],control_timeseries[n][end],(ps,),(iv,))])
162+
163+
# Calculate final pieces
164+
equalities = vcat(stages,updates,control_equality)
165+
opt_states = vcat(reduce(vcat,reduce(vcat,states_timeseries)),reduce(vcat,reduce(vcat,k_timeseries)),reduce(vcat,reduce(vcat,control_timeseries)))
166+
167+
OptimizationSystem(reduce(+,losses),opt_states,ps,equality_constraints = equalities)
168+
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct OptimizationSystem <: AbstractSystem
2525
ps::Vector{Variable}
2626
pins::Vector{Variable}
2727
observed::Vector{Equation}
28+
equality_constraints::Vector{Equation}
29+
inequality_constraints::Vector{Operation}
2830
"""
2931
Name: the name of the system
3032
"""
@@ -38,9 +40,11 @@ end
3840
function OptimizationSystem(op, states, ps;
3941
pins = Variable[],
4042
observed = Operation[],
43+
equality_constraints = Equation[],
44+
inequality_constraints = Operation[],
4145
name = gensym(:OptimizationSystem),
4246
systems = OptimizationSystem[])
43-
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), pins, observed, name, systems)
47+
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), pins, observed, equality_constraints, inequality_constraints, name, systems)
4448
end
4549

4650
function calculate_gradient(sys::OptimizationSystem)

test/controlsystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using ModelingToolkit
2+
3+
@variables t x(t) v(t) u(t)
4+
@derivatives D'~t
5+
6+
loss = (4-x)^2 + 2v^2 + u^2
7+
eqs = [
8+
D(x) ~ v
9+
D(v) ~ u^3
10+
]
11+
12+
sys = ControlSystem(loss,eqs,t,[x,v],[u],[])
13+
dt = 0.1
14+
tspan = (0.0,1.0)
15+
runge_kutta_discretize(sys,dt,tspan)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using SafeTestsets, Test
1616
@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end
1717
@safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end
1818
@safetestset "JumpSystem Test" begin include("jumpsystem.jl") end
19+
@safetestset "ControlSystem Test" begin include("controlsystem.jl") end
1920
@safetestset "Build Targets Test" begin include("build_targets.jl") end
2021
@safetestset "Domain Test" begin include("domains.jl") end
2122
@safetestset "Constraints Test" begin include("constraints.jl") end

0 commit comments

Comments
 (0)