Skip to content

Commit b555c15

Browse files
committed
Adapt controlsystems for use of parameter arrays
1 parent c458946 commit b555c15

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/systems/control/controlsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
160160
control_equality = reduce(vcat,[control_timeseries[i][end] .~ control_timeseries[i+1][1] for i in 1:n-1])
161161

162162
# Create the loss function
163-
losses = [Base.invokelatest(L,states_timeseries[i],control_timeseries[i][1],(ps,),(iv,)) for i in 1:n]
164-
losses = vcat(losses,[Base.invokelatest(L,states_timeseries[n+1],control_timeseries[n][end],(ps,),(iv,))])
163+
losses = [Base.invokelatest(L,states_timeseries[i],control_timeseries[i][1],ps,iv) for i in 1:n]
164+
losses = vcat(losses,[Base.invokelatest(L,states_timeseries[n+1],control_timeseries[n][end],ps,iv)])
165165

166166
# Calculate final pieces
167167
equalities = vcat(stages,updates,control_equality)

test/controlsystem.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
using ModelingToolkit
22

33
@variables t x(t) v(t) u(t)
4-
@parameters p
4+
@parameters p[1:1]
55
@derivatives D'~t
66

77
loss = (4-x)^2 + 2v^2 + u^2
88
eqs = [
9-
D(x) ~ v
10-
D(v) ~ p*u^3
9+
D(x) ~ v #- p[2]*x
10+
D(v) ~ p[1]*u^3 + v
1111
]
1212

13-
sys = ControlSystem(loss,eqs,t,[x,v],[u],[p])
13+
sys = ControlSystem(loss,eqs,t,[x,v],[u],p)
1414
dt = 0.1
1515
tspan = (0.0,1.0)
1616
sys = runge_kutta_discretize(sys,dt,tspan)
1717

1818
u0 = rand(112) # guess for the state values
1919
prob = OptimizationProblem(sys,u0,[0.1],grad=true)
20+

0 commit comments

Comments
 (0)