Skip to content

Commit 6aa2db8

Browse files
authored
Merge branch 'main' into ce
2 parents 8af5d8e + 67b8b5e commit 6aa2db8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+1727
-2264
lines changed

CMakeLists.txt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ set(MUJOCO_MPC_MUJOCO_GIT_TAG
6060
CACHE STRING "Git revision for MuJoCo."
6161
)
6262

63+
set(MUJOCO_MPC_MENAGERIE_GIT_TAG
64+
94ea114fa8c60a0fd542c8e1ffeb997204acbea2
65+
CACHE STRING "Git revision for MuJoCo Menagerie."
66+
)
67+
68+
set(MUJOCO_MPC_DM_CONTROL_GIT_TAG
69+
774f46182140106e22725914aad3c6299ed91edd
70+
CACHE STRING "Git revision for dm_control."
71+
)
72+
6373
findorfetch(
6474
USE_SYSTEM_PACKAGE
6575
OFF
@@ -160,14 +170,25 @@ unset(BUILD_SHARED_LIBS_OLD)
160170
FetchContent_Declare(
161171
menagerie
162172
GIT_REPOSITORY https://github.com/google-deepmind/mujoco_menagerie.git
163-
GIT_TAG main
173+
GIT_TAG ${MUJOCO_MPC_MENAGERIE_GIT_TAG}
164174
)
165175

166176
FetchContent_GetProperties(menagerie)
167177
if(NOT menagerie_POPULATED)
168178
FetchContent_Populate(menagerie)
169179
endif()
170180

181+
FetchContent_Declare(
182+
dm_control
183+
GIT_REPOSITORY https://github.com/google-deepmind/dm_control.git
184+
GIT_TAG ${MUJOCO_MPC_DM_CONTROL_GIT_TAG}
185+
)
186+
187+
FetchContent_GetProperties(dm_control)
188+
if(NOT dm_control_POPULATED)
189+
FetchContent_Populate(dm_control)
190+
endif()
191+
171192
if(NOT TARGET lodepng)
172193
FetchContent_Declare(
173194
lodepng

docs/CONTRIBUTING.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ This code adheres to the [Google style](https://google.github.io/styleguide/).
3232

3333
## New Tasks
3434

35-
When submitting a PR for a new task using models from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie), do not include assets directly. Instead, modify the task [CMakeLists](mjpc/tasks/CMakeLists.txt) to copy these assets to the build binary.
35+
When submitting a PR for a new task that depends on third-party models, including from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie) and [dm_control](https://github.com/google-deepmind/dm_control), do not include the xml model or assets in the task directly. Instead, modify the task [CMakeLists](mjpc/tasks/CMakeLists.txt) to copy the xml model and/or assets to the build binary.
36+
37+
If the xml model needs to be modified, create a patch that is applied in the [CMakeLists](mjpc/tasks/CMakeLists.txt). A [patch](https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/tasks/op3/op3.xml.patch) can be generated using the following command:
38+
```
39+
diff -u {original}.xml {modified}.xml > {modified}.xml.patch
40+
```
41+
The first three lines of the generated patch file will need to be be adapted for your use case. Please see an [example](https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/tasks/op3/op3.xml.patch) for a template.
3642

3743
## Unit Tests
3844

mjpc/CMakeLists.txt

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ add_library(
3838
tasks/tasks.h
3939
tasks/acrobot/acrobot.cc
4040
tasks/acrobot/acrobot.h
41+
tasks/allegro/allegro.cc
42+
tasks/allegro/allegro.h
4143
tasks/bimanual/bimanual.cc
4244
tasks/bimanual/bimanual.h
4345
tasks/cartpole/cartpole.cc
@@ -184,23 +186,33 @@ if(APPLE)
184186
target_link_libraries(mjpc "-framework Cocoa")
185187
endif()
186188

187-
add_executable(
188-
testspeed
189+
add_library(
190+
libtestspeed STATIC
189191
testspeed_app.cc
190192
testspeed.h
191193
testspeed.cc
192194
)
193195
target_link_libraries(
194-
testspeed
195-
absl::flags
196-
absl::flags_parse
196+
libtestspeed
197197
absl::random_random
198198
absl::strings
199199
libmjpc
200200
mujoco::mujoco
201201
threadpool
202202
Threads::Threads
203203
)
204+
target_include_directories(libtestspeed PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
205+
206+
add_executable(
207+
testspeed
208+
testspeed_app.cc
209+
)
210+
target_link_libraries(
211+
testspeed
212+
libtestspeed
213+
absl::flags
214+
absl::flags_parse
215+
)
204216
target_include_directories(testspeed PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
205217
target_compile_options(testspeed PUBLIC ${MJPC_COMPILE_OPTIONS})
206218
target_link_options(testspeed PRIVATE ${MJPC_LINK_OPTIONS})

mjpc/grpc/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ set(BUILD_SHARED_LIBS
2222

2323
find_package(ZLIB REQUIRED)
2424
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
25+
set(gRPC_BUILD_GRPC_CSHARP_PLUGIN OFF)
26+
set(gRPC_BUILD_GRPC_NODE_PLUGIN OFF)
27+
set(gRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN OFF)
28+
set(gRPC_BUILD_GRPC_PHP_PLUGIN OFF)
29+
set(gRPC_BUILD_GRPC_RUBY_PLUGIN OFF)
30+
set(RE2_BUILD_TESTING OFF)
2531
set(ZLIB_BUILD_EXAMPLES OFF)
2632

2733
findorfetch(

mjpc/grpc/ui_agent_service.cc

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,16 @@ namespace mjpc::agent_grpc {
3535

3636
using ::agent::GetActionRequest;
3737
using ::agent::GetActionResponse;
38-
using ::agent::GetModeRequest;
39-
using ::agent::GetModeResponse;
38+
using ::agent::GetAllModesRequest;
39+
using ::agent::GetAllModesResponse;
40+
using ::agent::GetBestTrajectoryRequest;
41+
using ::agent::GetBestTrajectoryResponse;
4042
using ::agent::GetResidualsRequest;
4143
using ::agent::GetResidualsResponse;
4244
using ::agent::GetCostValuesAndWeightsRequest;
4345
using ::agent::GetCostValuesAndWeightsResponse;
46+
using ::agent::GetModeRequest;
47+
using ::agent::GetModeResponse;
4448
using ::agent::GetStateRequest;
4549
using ::agent::GetStateResponse;
4650
using ::agent::GetTaskParametersRequest;
@@ -222,6 +226,24 @@ grpc::Status UiAgentService::GetMode(grpc::ServerContext* context,
222226
});
223227
}
224228

229+
grpc::Status UiAgentService::GetAllModes(grpc::ServerContext* context,
230+
const GetAllModesRequest* request,
231+
GetAllModesResponse* response) {
232+
return RunBeforeStep(
233+
context, [request, response](mjpc::Agent* agent, const mjModel* model,
234+
mjData* data) {
235+
return grpc_agent_util::GetAllModes(request, agent, response);
236+
});
237+
}
238+
239+
grpc::Status UiAgentService::GetBestTrajectory(
240+
grpc::ServerContext* context, const GetBestTrajectoryRequest* request,
241+
GetBestTrajectoryResponse* response) {
242+
// TODO - Implement.
243+
return {grpc::StatusCode::UNIMPLEMENTED,
244+
"GetBestTrajectory is not implemented."};
245+
}
246+
225247
grpc::Status UiAgentService::SetAnything(grpc::ServerContext* context,
226248
const SetAnythingRequest* request,
227249
SetAnythingResponse* response) {

mjpc/grpc/ui_agent_service.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
#ifndef MJPC_MJPC_GRPC_UI_AGENT_SERVICE_H_
1616
#define MJPC_MJPC_GRPC_UI_AGENT_SERVICE_H_
1717

18+
#include <absl/functional/any_invocable.h>
1819
#include <grpcpp/server_context.h>
1920
#include <grpcpp/support/status.h>
2021
#include <mujoco/mujoco.h>
2122

23+
#include <mjpc/agent.h>
2224
#include <mjpc/grpc/agent.grpc.pb.h>
2325
#include <mjpc/grpc/agent.pb.h>
2426
#include <mjpc/simulate.h> // mjpc fork
27+
#include <mjpc/states/state.h>
2528
#include <mjpc/utilities.h>
2629

2730
namespace mjpc::agent_grpc {
@@ -95,6 +98,15 @@ class UiAgentService final : public agent::Agent::Service {
9598
const agent::GetModeRequest* request,
9699
agent::GetModeResponse* response) override;
97100

101+
grpc::Status GetAllModes(grpc::ServerContext* context,
102+
const agent::GetAllModesRequest* request,
103+
agent::GetAllModesResponse* response) override;
104+
105+
grpc::Status GetBestTrajectory(
106+
grpc::ServerContext* context,
107+
const agent::GetBestTrajectoryRequest* request,
108+
agent::GetBestTrajectoryResponse* response) override;
109+
98110
grpc::Status SetAnything(grpc::ServerContext* context,
99111
const agent::SetAnythingRequest* request,
100112
agent::SetAnythingResponse* response) override;

mjpc/spline/spline.h

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -102,45 +102,41 @@ class TimeSpline {
102102
}
103103

104104
// Copyable, Movable.
105-
IteratorT<SplineType, NodeType>(
106-
const IteratorT<SplineType, NodeType>& other) = default;
107-
IteratorT<SplineType, NodeType>& operator=(
108-
const IteratorT<SplineType, NodeType>& other) = default;
109-
IteratorT<SplineType, NodeType>(IteratorT<SplineType, NodeType>&& other) =
110-
default;
111-
IteratorT<SplineType, NodeType>& operator=(
112-
IteratorT<SplineType, NodeType>&& other) = default;
105+
IteratorT(const IteratorT& other) = default;
106+
IteratorT& operator=(const IteratorT& other) = default;
107+
IteratorT(IteratorT&& other) = default;
108+
IteratorT& operator=(IteratorT&& other) = default;
113109

114110
reference operator*() { return node_; }
115111

116112
pointer operator->() { return &node_; }
117113
pointer operator->() const { return &node_; }
118114

119-
IteratorT<SplineType, NodeType>& operator++() {
115+
IteratorT& operator++() {
120116
++index_;
121117
node_ = index_ == spline_->Size() ? NodeType() : spline_->NodeAt(index_);
122118
return *this;
123119
}
124120

125-
IteratorT<SplineType, NodeType> operator++(int) {
126-
IteratorT<SplineType, NodeType> tmp = *this;
121+
IteratorT operator++(int) {
122+
IteratorT tmp = *this;
127123
++(*this);
128124
return tmp;
129125
}
130126

131-
IteratorT<SplineType, NodeType>& operator--() {
127+
IteratorT& operator--() {
132128
--index_;
133129
node_ = spline_->NodeAt(index_);
134130
return *this;
135131
}
136132

137-
IteratorT<SplineType, NodeType> operator--(int) {
138-
IteratorT<SplineType, NodeType> tmp = *this;
133+
IteratorT operator--(int) {
134+
IteratorT tmp = *this;
139135
--(*this);
140136
return tmp;
141137
}
142138

143-
IteratorT<SplineType, NodeType>& operator+=(difference_type n) {
139+
IteratorT& operator+=(difference_type n) {
144140
if (n != 0) {
145141
index_ += n;
146142
node_ =
@@ -149,29 +145,25 @@ class TimeSpline {
149145
return *this;
150146
}
151147

152-
IteratorT<SplineType, NodeType>& operator-=(difference_type n) {
153-
return *this += -n;
154-
}
148+
IteratorT& operator-=(difference_type n) { return *this += -n; }
155149

156-
IteratorT<SplineType, NodeType> operator+(difference_type n) const {
157-
IteratorT<SplineType, NodeType> tmp(*this);
150+
IteratorT operator+(difference_type n) const {
151+
IteratorT tmp(*this);
158152
tmp += n;
159153
return tmp;
160154
}
161155

162-
IteratorT<SplineType, NodeType> operator-(difference_type n) const {
163-
IteratorT<SplineType, NodeType> tmp(*this);
156+
IteratorT operator-(difference_type n) const {
157+
IteratorT tmp(*this);
164158
tmp -= n;
165159
return tmp;
166160
}
167161

168-
friend IteratorT<SplineType, NodeType> operator+(
169-
difference_type n, const IteratorT<SplineType, NodeType>& it) {
162+
friend IteratorT operator+(difference_type n, const IteratorT& it) {
170163
return it + n;
171164
}
172165

173-
friend difference_type operator-(const IteratorT<SplineType, NodeType>& x,
174-
const IteratorT<SplineType, NodeType>& y) {
166+
friend difference_type operator-(const IteratorT& x, const IteratorT& y) {
175167
CHECK_EQ(x.spline_, y.spline_)
176168
<< "Comparing iterators from different splines";
177169
if (x != y) return (x.index_ - y.index_);
@@ -180,35 +172,29 @@ class TimeSpline {
180172

181173
NodeType operator[](difference_type n) const { return *(*this + n); }
182174

183-
friend bool operator==(const IteratorT<SplineType, NodeType>& x,
184-
const IteratorT<SplineType, NodeType>& y) {
175+
friend bool operator==(const IteratorT& x, const IteratorT& y) {
185176
return x.spline_ == y.spline_ && x.index_ == y.index_;
186177
}
187178

188-
friend bool operator!=(const IteratorT<SplineType, NodeType>& x,
189-
const IteratorT<SplineType, NodeType>& y) {
179+
friend bool operator!=(const IteratorT& x, const IteratorT& y) {
190180
return !(x == y);
191181
}
192182

193-
friend bool operator<(const IteratorT<SplineType, NodeType>& x,
194-
const IteratorT<SplineType, NodeType>& y) {
183+
friend bool operator<(const IteratorT& x, const IteratorT& y) {
195184
CHECK_EQ(x.spline_, y.spline_)
196185
<< "Comparing iterators from different splines";
197186
return x.index_ < y.index_;
198187
}
199188

200-
friend bool operator>(const IteratorT<SplineType, NodeType>& x,
201-
const IteratorT<SplineType, NodeType>& y) {
189+
friend bool operator>(const IteratorT& x, const IteratorT& y) {
202190
return y < x;
203191
}
204192

205-
friend bool operator<=(const IteratorT<SplineType, NodeType>& x,
206-
const IteratorT<SplineType, NodeType>& y) {
193+
friend bool operator<=(const IteratorT& x, const IteratorT& y) {
207194
return !(y < x);
208195
}
209196

210-
friend bool operator>=(const IteratorT<SplineType, NodeType>& x,
211-
const IteratorT<SplineType, NodeType>& y) {
197+
friend bool operator>=(const IteratorT& x, const IteratorT& y) {
212198
return !(x < y);
213199
}
214200

0 commit comments

Comments
 (0)