|
| 1 | +abstract type AbstractControlSystem <: AbstractSystem end |
| 2 | + |
| 3 | +function namespace_controls(sys::AbstractSystem) |
| 4 | + [rename(x,renamespace(sys.name,x.name)) for x in controls(sys)] |
| 5 | +end |
| 6 | + |
| 7 | +function controls(sys::AbstractControlSystem,args...) |
| 8 | + name = last(args) |
| 9 | + extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]]) |
| 10 | + newname = renamespace(extra_names,name) |
| 11 | + rename(x,renamespace(sys.name,newname))(sys.iv()) |
| 12 | +end |
| 13 | + |
| 14 | +function controls(sys::AbstractControlSystem,name::Symbol) |
| 15 | + x = sys.controls[findfirst(x->x.name==name,sys.ps)] |
| 16 | + rename(x,renamespace(sys.name,x.name))() |
| 17 | +end |
| 18 | + |
| 19 | +controls(sys::AbstractControlSystem) = isempty(sys.systems) ? sys.controls : [sys.controls;reduce(vcat,namespace_controls.(sys.systems))] |
| 20 | + |
| 21 | +""" |
| 22 | +$(TYPEDEF) |
| 23 | +
|
| 24 | +A system describing an optimal control problem. This contains a loss function |
| 25 | +and ordinary differential equations with control variables that describe the |
| 26 | +dynamics. |
| 27 | +
|
| 28 | +# Fields |
| 29 | +$(FIELDS) |
| 30 | +
|
| 31 | +# Example |
| 32 | +
|
| 33 | +``` |
| 34 | +using ModelingToolkit |
| 35 | +
|
| 36 | +@variables t x(t) v(t) u(t) |
| 37 | +@derivatives D'~t |
| 38 | +
|
| 39 | +loss = (4-x)^2 + 2v^2 + u^2 |
| 40 | +eqs = [ |
| 41 | + D(x) ~ v |
| 42 | + D(v) ~ u^3 |
| 43 | +] |
| 44 | +
|
| 45 | +sys = ControlSystem(loss,eqs,t,[x,v],[u],[]) |
| 46 | +``` |
| 47 | +""" |
| 48 | +struct ControlSystem <: AbstractControlSystem |
| 49 | + """The Loss function""" |
| 50 | + loss::Operation |
| 51 | + """The ODEs defining the system.""" |
| 52 | + eqs::Vector{Equation} |
| 53 | + """Independent variable.""" |
| 54 | + iv::Variable |
| 55 | + """Dependent (state) variables.""" |
| 56 | + states::Vector{Variable} |
| 57 | + """Control variables.""" |
| 58 | + controls::Vector{Variable} |
| 59 | + """Parameter variables.""" |
| 60 | + ps::Vector{Variable} |
| 61 | + pins::Vector{Variable} |
| 62 | + observed::Vector{Equation} |
| 63 | + """ |
| 64 | + Name: the name of the system |
| 65 | + """ |
| 66 | + name::Symbol |
| 67 | + """ |
| 68 | + systems: The internal systems |
| 69 | + """ |
| 70 | + systems::Vector{ControlSystem} |
| 71 | +end |
| 72 | + |
| 73 | +function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps; |
| 74 | + pins = Variable[], |
| 75 | + observed = Operation[], |
| 76 | + systems = ODESystem[], |
| 77 | + name=gensym(:ControlSystem)) |
| 78 | + iv′ = convert(Variable,iv) |
| 79 | + dvs′ = convert.(Variable,dvs) |
| 80 | + controls′ = convert.(Variable,controls) |
| 81 | + ps′ = convert.(Variable,ps) |
| 82 | + ControlSystem(loss, deqs, iv′, dvs′, controls′, ps′, pins, observed, name, systems) |
| 83 | +end |
| 84 | + |
| 85 | +struct ControlToExpr |
| 86 | + sys::AbstractControlSystem |
| 87 | + states::Vector{Variable} |
| 88 | + controls::Vector{Variable} |
| 89 | +end |
| 90 | +ControlToExpr(@nospecialize(sys)) = ControlToExpr(sys,states(sys),controls(sys)) |
| 91 | +function (f::ControlToExpr)(O::Operation) |
| 92 | + if isa(O.op, Variable) |
| 93 | + isequal(O.op, f.sys.iv) && return O.op.name # independent variable |
| 94 | + O.op ∈ f.states && return O.op.name # dependent variables |
| 95 | + O.op ∈ f.controls && return O.op.name # control variables |
| 96 | + isempty(O.args) && return O.op.name # 0-ary parameters |
| 97 | + return build_expr(:call, Any[O.op.name; f.(O.args)]) |
| 98 | + end |
| 99 | + return build_expr(:call, Any[Symbol(O.op); f.(O.args)]) |
| 100 | +end |
| 101 | +(f::ControlToExpr)(x) = convert(Expr, x) |
| 102 | + |
| 103 | +function constructRadauIIA5(T::Type = Float64) |
| 104 | + sq6 = sqrt(6) |
| 105 | + A = [11//45-7sq6/360 37//225-169sq6/1800 -2//225+sq6/75 |
| 106 | + 37//225+169sq6/1800 11//45+7sq6/360 -2//225-sq6/75 |
| 107 | + 4//9-sq6/36 4//9+sq6/36 1//9] |
| 108 | + c = [2//5-sq6/10;2/5+sq6/10;1] |
| 109 | + α = [4//9-sq6/36;4//9+sq6/36;1//9] |
| 110 | + A = map(T,A) |
| 111 | + α = map(T,α) |
| 112 | + c = map(T,c) |
| 113 | + return(DiffEqBase.ImplicitRKTableau(A,c,α,5)) |
| 114 | +end |
| 115 | + |
| 116 | + |
| 117 | +""" |
| 118 | +```julia |
| 119 | +runge_kutta_discretize(sys::ControlSystem,dt,tspan; |
| 120 | + tab = ModelingToolkit.constructRadauIIA5()) |
| 121 | +``` |
| 122 | +
|
| 123 | +Transforms a nonlinear optimal control problem into a constrained |
| 124 | +`OptimizationProblem` according to a Runge-Kutta tableau that describes |
| 125 | +a collocation method. Requires a fixed `dt` over a given `timespan`. |
| 126 | +Defaults to using the 5th order RadauIIA tableau, and altnerative tableaus |
| 127 | +can be specified using the SciML tableau style. |
| 128 | +""" |
| 129 | +function runge_kutta_discretize(sys::ControlSystem,dt,tspan; |
| 130 | + tab = ModelingToolkit.constructRadauIIA5()) |
| 131 | + n = length(tspan[1]:dt:tspan[2]) - 1 |
| 132 | + m = length(tab.α) |
| 133 | + |
| 134 | + f = eval(build_function([x.rhs for x in equations(sys)],sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))[1]) |
| 135 | + L = eval(build_function(sys.loss,sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))) |
| 136 | + |
| 137 | + # Expand out all of the variables in time and by stages |
| 138 | + timed_vars = [[Variable(x.name,i)(sys.iv()) for i in 1:n+1] for x in states(sys)] |
| 139 | + k_vars = [[Variable(Symbol(:ᵏ,x.name),i,j)(sys.iv()) for i in 1:m, j in 1:n] for x in states(sys)] |
| 140 | + states_timeseries = [getindex.(timed_vars,j) for j in 1:n+1] |
| 141 | + k_timeseries = [[getindex.(k_vars,i,j) for i in 1:m] for j in 1:n] |
| 142 | + control_timeseries = [[[Variable(x.name,i,j)(sys.iv()) for x in controls(sys)] for i in 1:m] for j in 1:n] |
| 143 | + ps = parameters(sys) |
| 144 | + iv = sys.iv() |
| 145 | + |
| 146 | + # Calculate all of the update and stage equations |
| 147 | + mult = [tab.A * k_timeseries[i] for i in 1:n] |
| 148 | + tmps = [[states_timeseries[i] .+ mult[i][j] for j in 1:m] for i in 1:n] |
| 149 | + |
| 150 | + bs = [states_timeseries[i] .+ dt .* sum(tab.α .* k_timeseries[i],dims=1)[1] for i in 1:n] |
| 151 | + updates = reduce(vcat,[states_timeseries[i+1] .~ bs[i] for i in 1:n]) |
| 152 | + |
| 153 | + df = [[dt .* Base.invokelatest(f,tmps[j][i],control_timeseries[j][i],ps,iv) for i in 1:m] for j in 1:n] |
| 154 | + stages = reduce(vcat,[k_timeseries[i][j] .~ df[i][j] for i in 1:n for j in 1:m]) |
| 155 | + |
| 156 | + # Enforce equalities in the controls |
| 157 | + control_equality = reduce(vcat,[control_timeseries[i][end] .~ control_timeseries[i+1][1] for i in 1:n-1]) |
| 158 | + |
| 159 | + # Create the loss function |
| 160 | + losses = [Base.invokelatest(L,states_timeseries[i],control_timeseries[i][1],(ps,),(iv,)) for i in 1:n] |
| 161 | + losses = vcat(losses,[Base.invokelatest(L,states_timeseries[n+1],control_timeseries[n][end],(ps,),(iv,))]) |
| 162 | + |
| 163 | + # Calculate final pieces |
| 164 | + equalities = vcat(stages,updates,control_equality) |
| 165 | + opt_states = vcat(reduce(vcat,reduce(vcat,states_timeseries)),reduce(vcat,reduce(vcat,k_timeseries)),reduce(vcat,reduce(vcat,control_timeseries))) |
| 166 | + |
| 167 | + OptimizationSystem(reduce(+,losses),opt_states,ps,equality_constraints = equalities) |
| 168 | +end |
0 commit comments