@@ -47,18 +47,18 @@ sys = ControlSystem(loss,eqs,t,[x,v],[u],[])
47
47
"""
48
48
struct ControlSystem <: AbstractControlSystem
49
49
""" The Loss function"""
50
- loss:: Operation
50
+ loss:: Term
51
51
""" The ODEs defining the system."""
52
52
eqs:: Vector{Equation}
53
53
""" Independent variable."""
54
- iv:: Variable
54
+ iv:: Sym
55
55
""" Dependent (state) variables."""
56
- states:: Vector{Variable}
56
+ states:: Vector
57
57
""" Control variables."""
58
- controls:: Vector{Variable}
58
+ controls:: Vector
59
59
""" Parameter variables."""
60
- ps:: Vector{Variable}
61
- pins:: Vector{Variable}
60
+ ps:: Vector
61
+ pins:: Vector
62
62
observed:: Vector{Equation}
63
63
"""
64
64
Name: the name of the system
@@ -71,34 +71,35 @@ struct ControlSystem <: AbstractControlSystem
71
71
end
72
72
73
73
function ControlSystem (loss, deqs:: AbstractVector{<:Equation} , iv, dvs, controls, ps;
74
- pins = Variable [],
75
- observed = Operation [],
74
+ pins = [],
75
+ observed = [],
76
76
systems = ODESystem[],
77
77
name= gensym (:ControlSystem ))
78
- iv′ = convert (Variable,iv)
79
- dvs′ = convert .(Variable,dvs)
80
- controls′ = convert .(Variable,controls)
81
- ps′ = convert .(Variable,ps)
82
- ControlSystem (loss, deqs, iv′, dvs′, controls′, ps′, pins, observed, name, systems)
78
+ iv′ = value (iv)
79
+ dvs′ = value .(dvs)
80
+ controls′ = value .(controls)
81
+ ps′ = value .(ps)
82
+ ControlSystem (value (loss), deqs, iv′, dvs′, controls′,
83
+ ps′, pins, observed, name, systems)
83
84
end
84
85
85
86
struct ControlToExpr
86
87
sys:: AbstractControlSystem
87
- states:: Vector{Variable}
88
- controls:: Vector{Variable}
88
+ states:: Vector
89
+ controls:: Vector
89
90
end
90
91
ControlToExpr (@nospecialize (sys)) = ControlToExpr (sys,states (sys),controls (sys))
91
- function (f:: ControlToExpr )(O:: Operation )
92
- if isa (O. op, Variable )
93
- isequal (O. op , f. sys . iv) && return O. op. name # independent variable
94
- O . op ∈ f . states && return O. op. name # dependent variables
95
- O. op ∈ f . controls && return O . op . name # control variables
96
- isempty (O . args) && return O . op . name # 0-ary parameters
97
- return build_expr (:call , Any[O. op. name ; f .(O. args)])
92
+ function (f:: ControlToExpr )(O:: Term )
93
+ res = if isa (O. op, Sym )
94
+ any ( isequal (O) , f. states) && return O. op. name # dependent variables
95
+ any ( isequal (O), f . controls) && return O. op. name # control variables
96
+ build_expr ( :call , Any[ O. op. name; f .(O . args)])
97
+ else
98
+ build_expr (:call , Any[Symbol ( O. op) ; f .(O. args)])
98
99
end
99
- return build_expr (:call , Any[Symbol (O. op); f .(O. args)])
100
100
end
101
- (f:: ControlToExpr )(x) = convert (Expr, x)
101
+ (f:: ControlToExpr )(x:: Sym ) = x. name
102
+ (f:: ControlToExpr )(x) = x
102
103
103
104
function constructRadauIIA5 (T:: Type = Float64)
104
105
sq6 = sqrt (6 )
@@ -134,20 +135,22 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
134
135
f = eval (build_function ([x. rhs for x in equations (sys)],sys. states,sys. controls,sys. ps,sys. iv,conv = ModelingToolkit. ControlToExpr (sys))[1 ])
135
136
L = eval (build_function (sys. loss,sys. states,sys. controls,sys. ps,sys. iv,conv = ModelingToolkit. ControlToExpr (sys)))
136
137
138
+ var (n, i... ) = var (nameof (n), i... )
139
+ var (n:: Symbol , i... ) = Sym {FnType{Tuple{symtype(sys.iv)}, Number}} (nameof (Variable (n, i... )))
137
140
# Expand out all of the variables in time and by stages
138
- timed_vars = [[Variable (x. name ,i)(sys. iv () ) for i in 1 : n+ 1 ] for x in states (sys)]
139
- k_vars = [[Variable (Symbol (:ᵏ ,x . name) ,i,j)(sys. iv () ) for i in 1 : m, j in 1 : n] for x in states (sys)]
141
+ timed_vars = [[var (x. op ,i)(sys. iv) for i in 1 : n+ 1 ] for x in states (sys)]
142
+ k_vars = [[var (Symbol (:ᵏ ,nameof (x . op)) ,i,j)(sys. iv) for i in 1 : m, j in 1 : n] for x in states (sys)]
140
143
states_timeseries = [getindex .(timed_vars,j) for j in 1 : n+ 1 ]
141
- k_timeseries = [[getindex .(k_vars,i,j) for i in 1 : m] for j in 1 : n]
142
- control_timeseries = [[[Variable (x. name ,i,j)(sys. iv () ) for x in controls (sys)] for i in 1 : m] for j in 1 : n]
144
+ k_timeseries = [[Num .( getindex .(k_vars,i,j) ) for i in 1 : m] for j in 1 : n]
145
+ control_timeseries = [[[var (x. op ,i,j)(sys. iv) for x in controls (sys)] for i in 1 : m] for j in 1 : n]
143
146
ps = parameters (sys)
144
- iv = sys. iv ()
147
+ iv = sys. iv
145
148
146
149
# Calculate all of the update and stage equations
147
150
mult = [tab. A * k_timeseries[i] for i in 1 : n]
148
151
tmps = [[states_timeseries[i] .+ mult[i][j] for j in 1 : m] for i in 1 : n]
149
152
150
- bs = [states_timeseries[i] .+ dt .* sum ( tab. α .* k_timeseries[i],dims= 1 )[1 ] for i in 1 : n]
153
+ bs = [states_timeseries[i] .+ dt .* reduce ( + , tab. α .* k_timeseries[i],dims= 1 )[1 ] for i in 1 : n]
151
154
updates = reduce (vcat,[states_timeseries[i+ 1 ] .~ bs[i] for i in 1 : n])
152
155
153
156
df = [[dt .* Base. invokelatest (f,tmps[j][i],control_timeseries[j][i],ps,iv) for i in 1 : m] for j in 1 : n]
@@ -164,5 +167,5 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
164
167
equalities = vcat (stages,updates,control_equality)
165
168
opt_states = vcat (reduce (vcat,reduce (vcat,states_timeseries)),reduce (vcat,reduce (vcat,k_timeseries)),reduce (vcat,reduce (vcat,control_timeseries)))
166
169
167
- OptimizationSystem (reduce (+ ,losses),opt_states,ps,equality_constraints = equalities)
170
+ OptimizationSystem (reduce (+ ,losses, init = 0 ),opt_states,ps,equality_constraints = equalities)
168
171
end
0 commit comments