Skip to content

Commit 538b64b

Browse files
start nonlinear optimal control
1 parent 95f14a3 commit 538b64b

File tree

6 files changed

+152
-3
lines changed

6 files changed

+152
-3
lines changed

src/ModelingToolkit.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ include("systems/nonlinear/nonlinearsystem.jl")
108108

109109
include("systems/optimization/optimizationsystem.jl")
110110

111+
include("systems/control/controlsystem.jl")
112+
111113
include("systems/pde/pdesystem.jl")
112114

113115
include("systems/reaction/reactionsystem.jl")
@@ -125,14 +127,16 @@ export OptimizationProblem, OptimizationProblemExpr
125127
export SteadyStateProblem, SteadyStateProblemExpr
126128
export JumpProblem, DiscreteProblem
127129
export NonlinearSystem, OptimizationSystem
130+
export ControlSystem
128131
export ode_order_lowering
132+
export runge_kutta_discretize
129133
export PDESystem
130134
export Reaction, ReactionSystem, ismassaction, oderatelaw, jumpratelaw
131135
export Differential, expand_derivatives, @derivatives
132136
export IntervalDomain, ProductDomain, , CircleDomain
133137
export Equation, ConstrainedEquation
134138
export Operation, Expression, Variable
135-
export independent_variable, states, parameters, equations, pins, observed
139+
export independent_variable, states, controls, parameters, equations, pins, observed
136140

137141
export calculate_jacobian, generate_jacobian, generate_function
138142
export calculate_tgrad, generate_tgrad

src/systems/abstractsystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ end
171171

172172
namespace_equations(sys::AbstractSystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
173173

174-
175174
function namespace_equation(eq::Equation,name,ivname)
176175
_lhs = namespace_operation(eq.lhs,name,ivname)
177176
_rhs = namespace_operation(eq.rhs,name,ivname)

src/systems/control/controlsystem.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
abstract type AbstractControlSystem <: AbstractSystem end
2+
3+
function namespace_controls(sys::AbstractSystem)
4+
[rename(x,renamespace(sys.name,x.name)) for x in controls(sys)]
5+
end
6+
7+
function controls(sys::AbstractControlSystem,args...)
8+
name = last(args)
9+
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
10+
newname = renamespace(extra_names,name)
11+
rename(x,renamespace(sys.name,newname))(sys.iv())
12+
end
13+
14+
function controls(sys::AbstractControlSystem,name::Symbol)
15+
x = sys.controls[findfirst(x->x.name==name,sys.ps)]
16+
rename(x,renamespace(sys.name,x.name))()
17+
end
18+
19+
controls(sys::AbstractControlSystem) = isempty(sys.systems) ? sys.controls : [sys.controls;reduce(vcat,namespace_controls.(sys.systems))]
20+
21+
struct ControlSystem <: AbstractControlSystem
22+
"""The Loss function"""
23+
loss::Operation
24+
"""The ODEs defining the system."""
25+
eqs::Vector{Equation}
26+
"""Independent variable."""
27+
iv::Variable
28+
"""Dependent (state) variables."""
29+
states::Vector{Variable}
30+
"""Control variables."""
31+
controls::Vector{Variable}
32+
"""Parameter variables."""
33+
ps::Vector{Variable}
34+
pins::Vector{Variable}
35+
observed::Vector{Equation}
36+
"""
37+
Name: the name of the system
38+
"""
39+
name::Symbol
40+
"""
41+
systems: The internal systems
42+
"""
43+
systems::Vector{ControlSystem}
44+
end
45+
46+
function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps;
47+
pins = Variable[],
48+
observed = Operation[],
49+
systems = ODESystem[],
50+
name=gensym(:ControlSystem))
51+
iv′ = convert(Variable,iv)
52+
dvs′ = convert.(Variable,dvs)
53+
controls′ = convert.(Variable,controls)
54+
ps′ = convert.(Variable,ps)
55+
ControlSystem(loss, deqs, iv′, dvs′, controls′, ps′, pins, observed, name, systems)
56+
end
57+
58+
struct ControlToExpr
59+
sys::AbstractControlSystem
60+
states::Vector{Variable}
61+
controls::Vector{Variable}
62+
end
63+
ControlToExpr(@nospecialize(sys)) = ControlToExpr(sys,states(sys),controls(sys))
64+
function (f::ControlToExpr)(O::Operation)
65+
if isa(O.op, Variable)
66+
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
67+
O.op f.states && return O.op.name # dependent variables
68+
O.op f.controls && return O.op.name # control variables
69+
isempty(O.args) && return O.op.name # 0-ary parameters
70+
return build_expr(:call, Any[O.op.name; f.(O.args)])
71+
end
72+
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
73+
end
74+
(f::ControlToExpr)(x) = convert(Expr, x)
75+
76+
function constructRadauIIA5(T::Type = Float64)
77+
sq6 = sqrt(6)
78+
A = [11//45-7sq6/360 37//225-169sq6/1800 -2//225+sq6/75
79+
37//225+169sq6/1800 11//45+7sq6/360 -2//225-sq6/75
80+
4//9-sq6/36 4//9+sq6/36 1//9]
81+
c = [2//5-sq6/10;2/5+sq6/10;1]
82+
α = [4//9-sq6/36;4//9+sq6/36;1//9]
83+
A = map(T,A)
84+
α = map(T,α)
85+
c = map(T,c)
86+
return(DiffEqBase.ImplicitRKTableau(A,c,α,5))
87+
end
88+
89+
function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
90+
tab = ModelingToolkit.constructRadauIIA5())
91+
n = length(tspan[1]:dt:tspan[2]) - 1
92+
m = length(tab.α)
93+
94+
f = eval(build_function([x.rhs for x in equations(sys)],sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))[1])
95+
L = eval(build_function(sys.loss,sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys)))
96+
97+
# Expand out all of the variables in time and by stages
98+
timed_vars = [[Variable(x.name,i)(sys.iv()) for i in 1:n+1] for x in states(sys)]
99+
k_vars = [[Variable(Symbol(:ᵏ,x.name),i,j)(sys.iv()) for i in 1:m, j in 1:n] for x in states(sys)]
100+
states_timeseries = [getindex.(timed_vars,j) for j in 1:n+1]
101+
k_timeseries = [[getindex.(k_vars,i,j) for i in 1:m] for j in 1:n]
102+
control_timeseries = [[[Variable(x.name,i,j)(sys.iv()) for x in controls(sys)] for i in 1:m] for j in 1:n]
103+
ps = parameters(sys)
104+
iv = sys.iv()
105+
106+
# Calculate all of the update and stage equations
107+
mult = [tab.A * k_timeseries[i] for i in 1:n]
108+
tmps = [[states_timeseries[i] .+ mult[i][j] for j in 1:m] for i in 1:n]
109+
110+
bs = [states_timeseries[i] .+ dt .* sum(tab.α .* k_timeseries[i],dims=1)[1] for i in 1:n]
111+
updates = reduce(vcat,[states_timeseries[i+1] .~ bs[i] for i in 1:n])
112+
113+
df = [[dt .* Base.invokelatest(f,tmps[j][i],control_timeseries[j][i],ps,iv) for i in 1:m] for j in 1:n]
114+
stages = reduce(vcat,[k_timeseries[i][j] .~ df[i][j] for i in 1:n for j in 1:m])
115+
116+
# Enforce equalities in the controls
117+
control_equality = reduce(vcat,[control_timeseries[i][end] .~ control_timeseries[i+1][1] for i in 1:n-1])
118+
119+
# Create the loss function
120+
losses = [Base.invokelatest(L,states_timeseries[i],control_timeseries[i][1],(ps,),(iv,)) for i in 1:n]
121+
losses = vcat(losses,[Base.invokelatest(L,states_timeseries[n+1],control_timeseries[n][end],(ps,),(iv,))])
122+
123+
# Calculate final pieces
124+
equalities = vcat(stages,updates,control_equality)
125+
opt_states = vcat(reduce(vcat,reduce(vcat,states_timeseries)),reduce(vcat,reduce(vcat,k_timeseries)),reduce(vcat,reduce(vcat,control_timeseries)))
126+
127+
OptimizationSystem(reduce(+,losses),opt_states,ps,constraints = equalities)
128+
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct OptimizationSystem <: AbstractSystem
2525
ps::Vector{Variable}
2626
pins::Vector{Variable}
2727
observed::Vector{Equation}
28+
constraints::Vector{Equation}
2829
"""
2930
Name: the name of the system
3031
"""
@@ -38,9 +39,10 @@ end
3839
function OptimizationSystem(op, states, ps;
3940
pins = Variable[],
4041
observed = Operation[],
42+
constraints = Equation[],
4143
name = gensym(:OptimizationSystem),
4244
systems = OptimizationSystem[])
43-
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), pins, observed, name, systems)
45+
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), pins, observed, constraints, name, systems)
4446
end
4547

4648
function calculate_gradient(sys::OptimizationSystem)

test/controlsystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using ModelingToolkit
2+
3+
@variables t x(t) v(t) u(t)
4+
@derivatives D'~t
5+
6+
loss = (4-x)^2 + 2v^2 + u^2
7+
eqs = [
8+
D(x) ~ v
9+
D(v) ~ u^3
10+
]
11+
12+
sys = ControlSystem(loss,eqs,t,[x,v],[u],[])
13+
dt = 0.1
14+
tspan = (0.0,1.0)
15+
runge_kutta_discretize(sys,dt,tspan)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using SafeTestsets, Test
1616
@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end
1717
@safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end
1818
@safetestset "JumpSystem Test" begin include("jumpsystem.jl") end
19+
@safetestset "ControlSystem Test" begin include("controlsystem.jl") end
1920
@safetestset "Build Targets Test" begin include("build_targets.jl") end
2021
@safetestset "Domain Test" begin include("domains.jl") end
2122
@safetestset "Constraints Test" begin include("constraints.jl") end

0 commit comments

Comments
 (0)