Skip to content

Commit 29e2d4e

Browse files
erez-tomcopybara-github
authored andcommitted
add an executable for testing planning speed.
PiperOrigin-RevId: 602733632 Change-Id: I7b46435c65ba1ca94dec26197c598ff8f5aa2635
1 parent 38a6638 commit 29e2d4e

File tree

5 files changed

+242
-0
lines changed

5 files changed

+242
-0
lines changed

mjpc/CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,28 @@ if(APPLE)
175175
target_link_libraries(mjpc "-framework Cocoa")
176176
endif()
177177

178+
add_executable(
179+
testspeed
180+
testspeed_app.cc
181+
testspeed.h
182+
testspeed.cc
183+
)
184+
target_link_libraries(
185+
testspeed
186+
absl::flags
187+
absl::flags_parse
188+
absl::random_random
189+
absl::strings
190+
libmjpc
191+
mujoco::mujoco
192+
threadpool
193+
Threads::Threads
194+
)
195+
target_include_directories(testspeed PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
196+
target_compile_options(testspeed PUBLIC ${MJPC_COMPILE_OPTIONS})
197+
target_link_options(testspeed PRIVATE ${MJPC_LINK_OPTIONS})
198+
target_compile_definitions(testspeed PRIVATE MJSIMULATE_STATIC)
199+
178200
add_subdirectory(tasks)
179201

180202
if(BUILD_TESTING AND MJPC_BUILD_TESTS)

mjpc/testspeed.cc

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "third_party/mujoco_mpc/mjpc/testspeed.h"
16+
17+
#include <chrono>
18+
#include <cmath>
19+
#include <cstdio>
20+
#include <iostream>
21+
#include <string>
22+
#include <vector>
23+
24+
#include <mujoco/mujoco.h>
25+
26+
#include "mjpc/agent.h"
27+
#include "mjpc/states/state.h"
28+
#include "mjpc/task.h"
29+
#include "mjpc/threadpool.h"
30+
#include "mjpc/utilities.h"
31+
#include "mjpc/tasks/tasks.h"
32+
33+
namespace mjpc {
34+
35+
namespace {
36+
Task* task;
37+
void residual_callback(const mjModel* model, mjData* data, int stage) {
38+
if (stage == mjSTAGE_ACC) {
39+
task->Residual(model, data, data->sensordata);
40+
}
41+
}
42+
} // namespace
43+
44+
// Run synchronous planning, print timing info,return 0 if nothing failed.
45+
int TestSpeed(std::string task_name, int planner_thread_count,
46+
int steps_per_planning_iteration, double total_time) {
47+
std::cout << "Test MJPC Speed\n";
48+
std::cout << " MuJoCo version " << mj_versionString() << "\n";
49+
if (mjVERSION_HEADER != mj_version()) {
50+
mju_error("Headers and library have Different versions");
51+
}
52+
std::cout << " Hardware threads: " << NumAvailableHardwareThreads() << "\n";
53+
54+
Agent agent;
55+
agent.SetTaskList(GetTasks());
56+
agent.gui_task_id = agent.GetTaskIdByName(task_name);
57+
if (agent.gui_task_id == -1) {
58+
std::cerr << "Invalid --task flag: '" << task_name
59+
<< "'. Valid values:\n";
60+
std::cerr << agent.GetTaskNames();
61+
return -1;
62+
}
63+
auto load_model = agent.LoadModel();
64+
mjModel* model = load_model.model.release();
65+
if (!model) {
66+
std::cerr << load_model.error << "\n";
67+
return 1;
68+
}
69+
mjData* data = mj_makeData(model);
70+
mj_forward(model, data);
71+
72+
int home_id = mj_name2id(model, mjOBJ_KEY, "home");
73+
if (home_id >= 0) {
74+
std::cout << "home_id: " << home_id << "\n";
75+
mj_resetDataKeyframe(model, data, home_id);
76+
}
77+
78+
// the planner and its initial configuration is set in the XML
79+
agent.estimator_enabled = false;
80+
agent.Initialize(model);
81+
agent.Allocate();
82+
agent.Reset(data->ctrl);
83+
agent.plan_enabled = true;
84+
85+
// make task available for global callback:
86+
task = agent.ActiveTask();
87+
mjcb_sensor = &residual_callback;
88+
89+
std::cout << " Planning threads: " << planner_thread_count << "\n";
90+
ThreadPool pool(planner_thread_count);
91+
92+
int total_steps = ceil(total_time / model->opt.timestep);
93+
int current_time = 0;
94+
double total_cost = 0;
95+
auto loop_start = std::chrono::steady_clock::now();
96+
for (int i = 0; i < total_steps; i++) {
97+
agent.ActiveTask()->Transition(model, data);
98+
agent.state.Set(model, data);
99+
100+
agent.ActivePlanner().ActionFromPolicy(
101+
data->ctrl, agent.state.state().data(),
102+
agent.state.time(), /*use_previous=*/false);
103+
mj_step(model, data);
104+
double cost = agent.ActiveTask()->CostValue(data->sensordata);
105+
total_cost += cost;
106+
107+
if (i % steps_per_planning_iteration == 0) { agent.PlanIteration(&pool); }
108+
109+
if (floor(data->time) > current_time) {
110+
current_time++;
111+
std::cout << "sim time: " << current_time << ", cost: " << cost << "\n";
112+
}
113+
}
114+
auto wall_run_time = std::chrono::duration_cast<std::chrono::microseconds>(
115+
std::chrono::steady_clock::now() - loop_start)
116+
.count() /
117+
1e6;
118+
std::cout << "Total wall time ("
119+
<< (int)ceil(total_steps / steps_per_planning_iteration)
120+
<< " planning steps): " << wall_run_time << " s ("
121+
<< total_time / wall_run_time << "x realtime)\n";
122+
std::cout << "Average cost per step (lower is better): "
123+
<< total_cost / total_steps << "\n";
124+
125+
mj_deleteData(data);
126+
mj_deleteModel(model);
127+
return 0;
128+
}
129+
} // namespace mjpc

mjpc/testspeed.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef MJPC_MJPC_TESTSPEED_H_
16+
#define MJPC_MJPC_TESTSPEED_H_
17+
18+
#include <string>
19+
20+
namespace mjpc {
21+
int TestSpeed(std::string task_name, int planner_thread_count,
22+
int steps_per_planning_iteration, double total_time);
23+
} // namespace mjpc
24+
25+
#endif // MJPC_MJPC_TESTSPEED_H_

mjpc/testspeed_app.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <string>
16+
17+
#include "base/init_google.h"
18+
#include <absl/flags/flag.h>
19+
20+
#include "third_party/mujoco_mpc/mjpc/testspeed.h"
21+
#include "mjpc/utilities.h"
22+
23+
ABSL_FLAG(std::string, task, "Cube Solving", "Which model to load on startup.");
24+
ABSL_FLAG(int, planner_thread, mjpc::NumAvailableHardwareThreads() - 5,
25+
"Number of planner threads to use.");
26+
ABSL_FLAG(int, steps_per_planning_iteration, 4,
27+
"How many physics steps to take between planning iterations.");
28+
ABSL_FLAG(double, total_time, 10, "Total time to simulate (seconds).");
29+
30+
int main(int argc, char** argv) {
31+
InitGoogle(argv[0], &argc, &argv, true);
32+
std::string task_name = absl::GetFlag(FLAGS_task);
33+
int planner_thread_count = absl::GetFlag(FLAGS_planner_thread);
34+
int steps_per_planning_iteration =
35+
absl::GetFlag(FLAGS_steps_per_planning_iteration);
36+
double total_time = absl::GetFlag(FLAGS_total_time);
37+
return mjpc::TestSpeed(task_name, planner_thread_count,
38+
steps_per_planning_iteration, total_time);
39+
}

mjpc/testspeed_test.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "third_party/mujoco_mpc/mjpc/testspeed.h"
16+
17+
#include "testing/base/public/gunit.h"
18+
19+
namespace {
20+
21+
TEST(TestSeed, Test) {
22+
EXPECT_EQ(
23+
mjpc::TestSpeed("Cartpole", /*planner_thread_count=*/10,
24+
/*steps_per_planning_iteration=*/10, /*total_time=*/10),
25+
0);
26+
}
27+
} // namespace

0 commit comments

Comments
 (0)