|
15 | 15 |
|
16 | 16 | using namespace aligator::gar; |
17 | 17 |
|
18 | | -const uint nx = 36; |
19 | | -const uint nu = 12; |
| 18 | +static constexpr uint nx = 36; |
| 19 | +static constexpr uint nu = 12; |
| 20 | +static constexpr uint nc = 32; |
| 21 | +static constexpr double mueq = 1e-11; |
| 22 | +static std::mt19937 rng; |
| 23 | +static normal_unary_op normal_op{rng}; |
20 | 24 |
|
21 | 25 | static void BM_serial(benchmark::State &state) { |
22 | 26 | uint horz = (uint)state.range(0); |
23 | | - VectorXs x0 = VectorXs::NullaryExpr(nx, normal_unary_op{}); |
24 | | - const LqrProblemTpl<double> problem = generate_problem(x0, horz, nx, nu); |
| 27 | + VectorXs x0 = VectorXs::NullaryExpr(nx, normal_op); |
| 28 | + const LqrProblemTpl<double> problem = |
| 29 | + generateLqProblem(rng, x0, horz, nx, nu, 0, nc); |
25 | 30 | ProximalRiccatiSolver<double> solver(problem); |
26 | | - const double mu = 1e-11; |
27 | 31 | auto [xs, us, vs, lbdas] = lqrInitializeSolution(problem); |
28 | 32 | for (auto _ : state) { |
29 | | - solver.backward(mu, mu); |
| 33 | + solver.backward(mueq); |
30 | 34 | solver.forward(xs, us, vs, lbdas); |
31 | 35 | } |
32 | 36 | } |
33 | 37 |
|
34 | 38 | #ifdef ALIGATOR_MULTITHREADING |
35 | 39 | template <uint NPROC> static void BM_parallel(benchmark::State &state) { |
36 | 40 | uint horz = (uint)state.range(0); |
37 | | - VectorXs x0 = VectorXs::NullaryExpr(nx, normal_unary_op{}); |
38 | | - LqrProblemTpl<double> problem = generate_problem(x0, horz, nx, nu); |
| 41 | + VectorXs x0 = VectorXs::NullaryExpr(nx, normal_op); |
| 42 | + LqrProblemTpl<double> problem = |
| 43 | + generateLqProblem(rng, x0, horz, nx, nu, 0, nc); |
39 | 44 | ParallelRiccatiSolver<double> solver(problem, NPROC); |
40 | | - const double mu = 1e-11; |
41 | 45 | auto [xs, us, vs, lbdas] = lqrInitializeSolution(problem); |
42 | 46 | for (auto _ : state) { |
43 | | - solver.backward(mu, mu); |
| 47 | + solver.backward(mueq); |
44 | 48 | solver.forward(xs, us, vs, lbdas); |
45 | 49 | } |
46 | 50 | } |
47 | 51 | #endif |
48 | 52 |
|
49 | 53 | static void BM_stagedense(benchmark::State &state) { |
50 | 54 | uint horz = (uint)state.range(0); |
51 | | - VectorXs x0 = VectorXs::NullaryExpr(nx, normal_unary_op{}); |
52 | | - LqrProblemTpl<double> problem = generate_problem(x0, horz, nx, nu); |
| 55 | + VectorXs x0 = VectorXs::NullaryExpr(nx, normal_op); |
| 56 | + LqrProblemTpl<double> problem = |
| 57 | + generateLqProblem(rng, x0, horz, nx, nu, 0, nc); |
53 | 58 | RiccatiSolverDense<double> solver(problem); |
54 | | - const double mu = 1e-11; |
55 | 59 | auto [xs, us, vs, lbdas] = lqrInitializeSolution(problem); |
56 | 60 | for (auto _ : state) { |
57 | | - solver.backward(mu, mu); |
| 61 | + solver.backward(mueq); |
58 | 62 | solver.forward(xs, us, vs, lbdas); |
59 | 63 | } |
60 | 64 | } |
|
0 commit comments