Skip to content

Commit 7eaf68d

Browse files
authored
Merge branch 'main' into sample_gradient_v2
2 parents c423cdb + dad4944 commit 7eaf68d

File tree

9 files changed

+254
-7
lines changed

9 files changed

+254
-7
lines changed

mjpc/CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,28 @@ if(APPLE)
177177
target_link_libraries(mjpc "-framework Cocoa")
178178
endif()
179179

180+
add_executable(
181+
testspeed
182+
testspeed_app.cc
183+
testspeed.h
184+
testspeed.cc
185+
)
186+
target_link_libraries(
187+
testspeed
188+
absl::flags
189+
absl::flags_parse
190+
absl::random_random
191+
absl::strings
192+
libmjpc
193+
mujoco::mujoco
194+
threadpool
195+
Threads::Threads
196+
)
197+
target_include_directories(testspeed PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
198+
target_compile_options(testspeed PUBLIC ${MJPC_COMPILE_OPTIONS})
199+
target_link_options(testspeed PRIVATE ${MJPC_LINK_OPTIONS})
200+
target_compile_definitions(testspeed PRIVATE MJSIMULATE_STATIC)
201+
180202
add_subdirectory(tasks)
181203

182204
if(BUILD_TESTING AND MJPC_BUILD_TESTS)

mjpc/tasks/cube/cube_3x3x3.xml.patch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ diff --git a/cube_3x3x3_modified.xml b/cube_3x3x3_modified.xml
2323
<default class="cubelet">
2424
- <joint type="ball" armature="0.0001" damping="0.0005" frictionloss="0.001"/>
2525
- <geom type="mesh" condim="1" mesh="cubelet" euler="0 0 90"/>
26-
+ <joint type="ball" armature="0.00005" damping="0.0001" frictionloss="0.00005"/>
26+
+ <joint type="ball" armature="0.0001" damping="0.0005" frictionloss="0.00005"/>
2727
+ <geom type="mesh" condim="1" mesh="cubelet" quat="1 0 0 1"/>
2828
</default>
2929
<default class="core">

mjpc/tasks/cube/solve.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mjpc/tasks/cube/solve.h"
1616

1717
#include <algorithm>
18+
#include <iostream>
1819
#include <random>
1920
#include <string>
2021

@@ -131,6 +132,7 @@ void CubeSolve::ResidualFn::Residual(const mjModel* model, const mjData* data,
131132
void CubeSolve::TransitionLocked(mjModel* model, mjData* data) {
132133
if (transition_model_) {
133134
if (mode == kModeWait) {
135+
weight[11] = .01; // add penalty on joint movement
134136
// wait
135137
} else if (mode == kModeScramble) { // scramble
136138
double scramble_param = parameters[6];
@@ -190,9 +192,11 @@ void CubeSolve::TransitionLocked(mjModel* model, mjData* data) {
190192

191193
// set face goal index
192194
goal_index_ = num_scramble - 1;
195+
std::cout << "rotations required: " << num_scramble << "\n";
193196

194197
// set to solve
195198
mode = kModeSolve;
199+
weight[11] = 0; // remove penalty on joint movement
196200
} else if (mode == kModeSolve) { // solve
197201
// set goal
198202
mju_copy(parameters.data(), goal_cache_.data() + 6 * goal_index_, 6);
@@ -204,7 +208,9 @@ void CubeSolve::TransitionLocked(mjModel* model, mjData* data) {
204208
if (mju_norm(error, 6) < 0.085) {
205209
if (goal_index_ == 0) {
206210
mode = kModeWait;
211+
std::cout << "solved!\n";
207212
} else {
213+
std::cout << "rotations remaining: " << goal_index_ << "\n";
208214
goal_index_--;
209215
}
210216
}
@@ -213,11 +219,10 @@ void CubeSolve::TransitionLocked(mjModel* model, mjData* data) {
213219

214220
// check for drop
215221
if (data->qpos[6] < kResetHeight) {
216-
// reset cube position + orientation
217-
mju_copy(data->qpos, model->key_qpos, 7);
222+
if (mode != kModeWait) { std::cout << "cube fell\n"; }
218223

219-
// reset cube velocity
220-
mju_zero(data->qvel, 6);
224+
// stop optimization
225+
mode = kModeWait;
221226
}
222227

223228
// check goal index

mjpc/tasks/cube/task.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
<numeric name="agent_policy_width" data="0.0035" />
1313
<numeric name="sampling_spline_points" data="6" />
1414
<numeric name="sampling_exploration" data="0.1" />
15-
<numeric name="sampling_trajectories" data="20" />
16-
<numeric name="sampling_representation" data="0" />
15+
<numeric name="sampling_trajectories" data="60" />
16+
<numeric name="sampling_representation" data="1" />
1717
<!-- manual face goals -->
1818
<numeric name="residual_Red" data="0 -3.14 3.14"/>
1919
<numeric name="residual_Orange" data="0 -3.14 3.14"/>

mjpc/tasks/tasks.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <memory>
1818
#include <vector>
1919

20+
#include "mjpc/task.h"
2021
#include "mjpc/tasks/acrobot/acrobot.h"
2122
#include "mjpc/tasks/cube/solve.h"
2223
#include "mjpc/tasks/cartpole/cartpole.h"

mjpc/testspeed.cc

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "mjpc/testspeed.h"
16+
17+
#include <chrono>
18+
#include <cmath>
19+
#include <iostream>
20+
#include <string>
21+
#include <vector>
22+
23+
#include <mujoco/mujoco.h>
24+
25+
#include "mjpc/agent.h"
26+
#include "mjpc/states/state.h"
27+
#include "mjpc/task.h"
28+
#include "mjpc/threadpool.h"
29+
#include "mjpc/utilities.h"
30+
#include "mjpc/tasks/tasks.h"
31+
32+
namespace mjpc {
33+
34+
namespace {
35+
Task* task;
36+
void residual_callback(const mjModel* model, mjData* data, int stage) {
37+
if (stage == mjSTAGE_ACC) {
38+
task->Residual(model, data, data->sensordata);
39+
}
40+
}
41+
} // namespace
42+
43+
// Run synchronous planning, print timing info,return 0 if nothing failed.
44+
int TestSpeed(std::string task_name, int planner_thread_count,
45+
int steps_per_planning_iteration, double total_time) {
46+
std::cout << "Test MJPC Speed\n";
47+
std::cout << " MuJoCo version " << mj_versionString() << "\n";
48+
if (mjVERSION_HEADER != mj_version()) {
49+
mju_error("Headers and library have Different versions");
50+
}
51+
std::cout << " Hardware threads: " << NumAvailableHardwareThreads() << "\n";
52+
53+
Agent agent;
54+
agent.SetTaskList(GetTasks());
55+
agent.gui_task_id = agent.GetTaskIdByName(task_name);
56+
if (agent.gui_task_id == -1) {
57+
std::cerr << "Invalid --task flag: '" << task_name
58+
<< "'. Valid values:\n";
59+
std::cerr << agent.GetTaskNames();
60+
return -1;
61+
}
62+
auto load_model = agent.LoadModel();
63+
mjModel* model = load_model.model.release();
64+
if (!model) {
65+
std::cerr << load_model.error << "\n";
66+
return 1;
67+
}
68+
mjData* data = mj_makeData(model);
69+
mj_forward(model, data);
70+
71+
int home_id = mj_name2id(model, mjOBJ_KEY, "home");
72+
if (home_id >= 0) {
73+
std::cout << "home_id: " << home_id << "\n";
74+
mj_resetDataKeyframe(model, data, home_id);
75+
}
76+
77+
// the planner and its initial configuration is set in the XML
78+
agent.estimator_enabled = false;
79+
agent.Initialize(model);
80+
agent.Allocate();
81+
agent.Reset(data->ctrl);
82+
agent.plan_enabled = true;
83+
84+
// make task available for global callback:
85+
task = agent.ActiveTask();
86+
mjcb_sensor = &residual_callback;
87+
88+
std::cout << " Planning threads: " << planner_thread_count << "\n";
89+
ThreadPool pool(planner_thread_count);
90+
91+
int total_steps = ceil(total_time / model->opt.timestep);
92+
int current_time = 0;
93+
double total_cost = 0;
94+
auto loop_start = std::chrono::steady_clock::now();
95+
for (int i = 0; i < total_steps; i++) {
96+
agent.ActiveTask()->Transition(model, data);
97+
agent.state.Set(model, data);
98+
99+
agent.ActivePlanner().ActionFromPolicy(
100+
data->ctrl, agent.state.state().data(),
101+
agent.state.time(), /*use_previous=*/false);
102+
mj_step(model, data);
103+
double cost = agent.ActiveTask()->CostValue(data->sensordata);
104+
total_cost += cost;
105+
106+
if (i % steps_per_planning_iteration == 0) { agent.PlanIteration(&pool); }
107+
108+
if (floor(data->time) > current_time) {
109+
current_time++;
110+
std::cout << "sim time: " << current_time << ", cost: " << cost << "\n";
111+
}
112+
}
113+
auto wall_run_time = std::chrono::duration_cast<std::chrono::microseconds>(
114+
std::chrono::steady_clock::now() - loop_start)
115+
.count() /
116+
1e6;
117+
std::cout << "Total wall time ("
118+
<< (int)ceil(total_steps / steps_per_planning_iteration)
119+
<< " planning steps): " << wall_run_time << " s ("
120+
<< total_time / wall_run_time << "x realtime)\n";
121+
std::cout << "Average cost per step (lower is better): "
122+
<< total_cost / total_steps << "\n";
123+
124+
mj_deleteData(data);
125+
mj_deleteModel(model);
126+
return 0;
127+
}
128+
} // namespace mjpc

mjpc/testspeed.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef MJPC_MJPC_TESTSPEED_H_
16+
#define MJPC_MJPC_TESTSPEED_H_
17+
18+
#include <string>
19+
20+
namespace mjpc {
21+
int TestSpeed(std::string task_name, int planner_thread_count,
22+
int steps_per_planning_iteration, double total_time);
23+
} // namespace mjpc
24+
25+
#endif // MJPC_MJPC_TESTSPEED_H_

mjpc/testspeed_app.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <string>
16+
17+
#include <absl/flags/parse.h>
18+
#include <absl/flags/flag.h>
19+
20+
#include "mjpc/testspeed.h"
21+
#include "mjpc/utilities.h"
22+
23+
ABSL_FLAG(std::string, task, "Cube Solving", "Which model to load on startup.");
24+
ABSL_FLAG(int, planner_thread, mjpc::NumAvailableHardwareThreads() - 5,
25+
"Number of planner threads to use.");
26+
ABSL_FLAG(int, steps_per_planning_iteration, 4,
27+
"How many physics steps to take between planning iterations.");
28+
ABSL_FLAG(double, total_time, 10, "Total time to simulate (seconds).");
29+
30+
int main(int argc, char** argv) {
31+
absl::ParseCommandLine(argc, argv);
32+
std::string task_name = absl::GetFlag(FLAGS_task);
33+
int planner_thread_count = absl::GetFlag(FLAGS_planner_thread);
34+
int steps_per_planning_iteration =
35+
absl::GetFlag(FLAGS_steps_per_planning_iteration);
36+
double total_time = absl::GetFlag(FLAGS_total_time);
37+
return mjpc::TestSpeed(task_name, planner_thread_count,
38+
steps_per_planning_iteration, total_time);
39+
}

mjpc/testspeed_test.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "mjpc/testspeed.h"
16+
17+
#include "gtest/gtest.h"
18+
19+
namespace {
20+
21+
TEST(TestSeed, Test) {
22+
EXPECT_EQ(
23+
mjpc::TestSpeed("Cartpole", /*planner_thread_count=*/10,
24+
/*steps_per_planning_iteration=*/10, /*total_time=*/10),
25+
0);
26+
}
27+
} // namespace

0 commit comments

Comments
 (0)