Skip to content

Commit 4568245

Browse files
authored
Merge pull request #3 from control-toolbox/2-gpu-testing
2 gpu testing
2 parents 6b68959 + f045d0c commit 4568245

File tree

3 files changed

+131
-1
lines changed

3 files changed

+131
-1
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ main.log
44
main.out
55
main.pdf
66
main.thm
7-
7+
Manifest.toml

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4+
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
5+
MadNLP = "2621e9c9-9eb4-46b1-8089-e8c72242dfb6"
6+
MadNLPGPU = "d72a61cc-809d-412f-99be-fd81f4b8a598"
7+
MadNLPMumps = "3b83494e-c0a4-4895-918b-9157a7a085a1"
8+
OptimalControl = "5f98b655-cc9a-415a-b60e-744165666948"

gpu-test.jl

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# goddard.jl
2+
3+
using OptimalControl
4+
using MadNLPMumps
5+
using MadNLPGPU
6+
using CUDA
7+
using BenchmarkTools
8+
using Interpolations
9+
import Logging
10+
11+
Logging.disable_logging(Logging.Warn) # disable warnings
12+
13+
# Goddard (bang-singular-boundary-bang structure)
14+
15+
function goddard()
16+
17+
r0 = 1.0
18+
v0 = 0.0
19+
m0 = 1.0
20+
vmax = 0.1
21+
mf = 0.6
22+
Cd = 310.0
23+
Tmax = 3.5
24+
β = 500.0
25+
b = 2.0
26+
27+
o = @def begin
28+
29+
tf R, variable
30+
t [0, tf], time
31+
x = (r, v, m) R³, state
32+
u R, control
33+
34+
x(0) == [r0, v0, m0]
35+
m(tf) == mf
36+
0 u(t) 1
37+
r(t) r0
38+
0 v(t) vmax
39+
40+
(r)(t) == v(t)
41+
(v)(t) == -Cd * v(t)^2 * exp(-β * (r(t) - 1)) / m(t) - 1 / r(t)^2 + u(t) * Tmax / m(t)
42+
(m)(t) == -b * Tmax * u(t)
43+
44+
r(tf) max
45+
46+
end
47+
48+
return o
49+
50+
end
51+
52+
# Quadrotor
53+
54+
function quadrotor()
55+
56+
T = 1
57+
g = 9.8
58+
r = 0.1
59+
60+
o = @def begin
61+
62+
t [0, T], time
63+
x R⁹, state
64+
u R⁴, control
65+
66+
x(0) == zeros(9)
67+
68+
(x₁)(t) == x₂(t)
69+
(x₂)(t) == u₁(t) * cos(x₇(t)) * sin(x₈(t)) * cos(x₉(t)) + u₁(t) * sin(x₇(t)) * sin(x₉(t))
70+
(x₃)(t) == x₄(t)
71+
(x₄)(t) == u₁(t) * cos(x₇(t)) * sin(x₈(t)) * sin(x₉(t)) - u₁(t) * sin(x₇(t)) * cos(x₉(t))
72+
(x₅)(t) == x₆(t)
73+
(x₆)(t) == u₁(t) * cos(x₇(t)) * cos(x₈(t)) - g
74+
(x₇)(t) == u₂(t) * cos(x₇(t)) / cos(x₈(t)) + u₃(t) * sin(x₇(t)) / cos(x₈(t))
75+
(x₈)(t) ==-u₂(t) * sin(x₇(t)) + u₃(t) * cos(x₇(t))
76+
(x₉)(t) == u₂(t) * cos(x₇(t)) * tan(x₈(t)) + u₃(t) * sin(x₇(t)) * tan(x₈(t)) + u₄(t)
77+
78+
dt1 = sin(2π * t / T)
79+
df1 = 0
80+
dt3 = 2sin(4π * t / T)
81+
df3 = 0
82+
dt5 = 2t / T
83+
df5 = 2
84+
85+
0.5( (x₁(t) - dt1)^2 + (x₃(t) - dt3)^2 + (x₅(t) - dt5)^2 + x₇(t)^2 + x₈(t)^2 + x₉(t)^2 +
86+
r * (u₁(t)^2 + u₂(t)^2 + u₃(t)^2 + u₄(t)^2) ) min
87+
88+
end
89+
90+
return o
91+
92+
end
93+
94+
# Solving
95+
96+
tol = 1e-7
97+
print_level = MadNLP.WARN
98+
#print_level = MadNLP.INFO
99+
100+
for (name, o) [("goddard", goddard()), ("quadrotor", quadrotor())]
101+
printstyled("\nProblem: $name\n"; bold=true)
102+
for N (100, 200, 500, 750, 1000, 2000, 5000, 7500, 10000, 20000, 50000, 75000, 100000)
103+
104+
m_cpu = model(direct_transcription(o, :exa; grid_size=N))
105+
m_gpu = model(direct_transcription(o, :exa; grid_size=N, exa_backend=CUDABackend()))
106+
printstyled("\nsolver = MadNLP", ", N = ", N, "\n"; bold=true)
107+
print("CPU:")
108+
try sol = @btime madnlp($m_cpu; print_level=$print_level, tol=$tol, linear_solver=MumpsSolver)
109+
println(" converged: ", sol.status == MadNLP.Status(1), ", iter: ", sol.iter)
110+
catch ex
111+
println("\n error: ", ex)
112+
end
113+
CUDA.functional() || throw("CUDA not available")
114+
print("GPU:")
115+
try madnlp(m_gpu; print_level=print_level, tol=tol);
116+
sol = CUDA.@time madnlp(m_gpu; print_level=print_level, tol=tol)
117+
println(" converged: ", sol.status == MadNLP.Status(1), ", iter: ", sol.iter)
118+
catch ex
119+
println("\n error: ", ex)
120+
end
121+
122+
end; end

0 commit comments

Comments
 (0)