Skip to content

Commit b8b0941

Browse files
committed
controlsystem fixes
1 parent 22bbafc commit b8b0941

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

src/systems/control/controlsystem.jl

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,18 @@ sys = ControlSystem(loss,eqs,t,[x,v],[u],[])
4747
"""
4848
struct ControlSystem <: AbstractControlSystem
4949
"""The Loss function"""
50-
loss::Operation
50+
loss::Term
5151
"""The ODEs defining the system."""
5252
eqs::Vector{Equation}
5353
"""Independent variable."""
54-
iv::Variable
54+
iv::Sym
5555
"""Dependent (state) variables."""
56-
states::Vector{Variable}
56+
states::Vector
5757
"""Control variables."""
58-
controls::Vector{Variable}
58+
controls::Vector
5959
"""Parameter variables."""
60-
ps::Vector{Variable}
61-
pins::Vector{Variable}
60+
ps::Vector
61+
pins::Vector
6262
observed::Vector{Equation}
6363
"""
6464
Name: the name of the system
@@ -71,34 +71,35 @@ struct ControlSystem <: AbstractControlSystem
7171
end
7272

7373
function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps;
74-
pins = Variable[],
75-
observed = Operation[],
74+
pins = [],
75+
observed = [],
7676
systems = ODESystem[],
7777
name=gensym(:ControlSystem))
78-
iv′ = convert(Variable,iv)
79-
dvs′ = convert.(Variable,dvs)
80-
controls′ = convert.(Variable,controls)
81-
ps′ = convert.(Variable,ps)
82-
ControlSystem(loss, deqs, iv′, dvs′, controls′, ps′, pins, observed, name, systems)
78+
iv′ = value(iv)
79+
dvs′ = value.(dvs)
80+
controls′ = value.(controls)
81+
ps′ = value.(ps)
82+
ControlSystem(value(loss), deqs, iv′, dvs′, controls′,
83+
ps′, pins, observed, name, systems)
8384
end
8485

8586
struct ControlToExpr
8687
sys::AbstractControlSystem
87-
states::Vector{Variable}
88-
controls::Vector{Variable}
88+
states::Vector
89+
controls::Vector
8990
end
9091
ControlToExpr(@nospecialize(sys)) = ControlToExpr(sys,states(sys),controls(sys))
91-
function (f::ControlToExpr)(O::Operation)
92-
if isa(O.op, Variable)
93-
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
94-
O.op f.states && return O.op.name # dependent variables
95-
O.op f.controls && return O.op.name # control variables
96-
isempty(O.args) && return O.op.name # 0-ary parameters
97-
return build_expr(:call, Any[O.op.name; f.(O.args)])
92+
function (f::ControlToExpr)(O::Term)
93+
res = if isa(O.op, Sym)
94+
any(isequal(O), f.states) && return O.op.name # dependent variables
95+
any(isequal(O), f.controls) && return O.op.name # control variables
96+
build_expr(:call, Any[O.op.name; f.(O.args)])
97+
else
98+
build_expr(:call, Any[Symbol(O.op); f.(O.args)])
9899
end
99-
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
100100
end
101-
(f::ControlToExpr)(x) = convert(Expr, x)
101+
(f::ControlToExpr)(x::Sym) = x.name
102+
(f::ControlToExpr)(x) = x
102103

103104
function constructRadauIIA5(T::Type = Float64)
104105
sq6 = sqrt(6)
@@ -134,20 +135,22 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
134135
f = eval(build_function([x.rhs for x in equations(sys)],sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))[1])
135136
L = eval(build_function(sys.loss,sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys)))
136137

138+
var(n, i...) = var(nameof(n), i...)
139+
var(n::Symbol, i...) = Sym{FnType{Tuple{symtype(sys.iv)}, Number}}(nameof(Variable(n, i...)))
137140
# Expand out all of the variables in time and by stages
138-
timed_vars = [[Variable(x.name,i)(sys.iv()) for i in 1:n+1] for x in states(sys)]
139-
k_vars = [[Variable(Symbol(:ᵏ,x.name),i,j)(sys.iv()) for i in 1:m, j in 1:n] for x in states(sys)]
141+
timed_vars = [[var(x.op,i)(sys.iv) for i in 1:n+1] for x in states(sys)]
142+
k_vars = [[var(Symbol(:ᵏ,nameof(x.op)),i,j)(sys.iv) for i in 1:m, j in 1:n] for x in states(sys)]
140143
states_timeseries = [getindex.(timed_vars,j) for j in 1:n+1]
141-
k_timeseries = [[getindex.(k_vars,i,j) for i in 1:m] for j in 1:n]
142-
control_timeseries = [[[Variable(x.name,i,j)(sys.iv()) for x in controls(sys)] for i in 1:m] for j in 1:n]
144+
k_timeseries = [[Num.(getindex.(k_vars,i,j)) for i in 1:m] for j in 1:n]
145+
control_timeseries = [[[var(x.op,i,j)(sys.iv) for x in controls(sys)] for i in 1:m] for j in 1:n]
143146
ps = parameters(sys)
144-
iv = sys.iv()
147+
iv = sys.iv
145148

146149
# Calculate all of the update and stage equations
147150
mult = [tab.A * k_timeseries[i] for i in 1:n]
148151
tmps = [[states_timeseries[i] .+ mult[i][j] for j in 1:m] for i in 1:n]
149152

150-
bs = [states_timeseries[i] .+ dt .* sum(tab.α .* k_timeseries[i],dims=1)[1] for i in 1:n]
153+
bs = [states_timeseries[i] .+ dt .* reduce(+, tab.α .* k_timeseries[i],dims=1)[1] for i in 1:n]
151154
updates = reduce(vcat,[states_timeseries[i+1] .~ bs[i] for i in 1:n])
152155

153156
df = [[dt .* Base.invokelatest(f,tmps[j][i],control_timeseries[j][i],ps,iv) for i in 1:m] for j in 1:n]
@@ -164,5 +167,5 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
164167
equalities = vcat(stages,updates,control_equality)
165168
opt_states = vcat(reduce(vcat,reduce(vcat,states_timeseries)),reduce(vcat,reduce(vcat,k_timeseries)),reduce(vcat,reduce(vcat,control_timeseries)))
166169

167-
OptimizationSystem(reduce(+,losses),opt_states,ps,equality_constraints = equalities)
170+
OptimizationSystem(reduce(+,losses, init=0),opt_states,ps,equality_constraints = equalities)
168171
end

0 commit comments

Comments
 (0)