Skip to content

Commit cd86fbf

Browse files
Merge pull request #279 from google-deepmind/deepmind
Merge deepmind branch into main.
2 parents dba4bd8 + 32309a8 commit cd86fbf

29 files changed

+1415
-31
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
$cmake_extra_args
6969
- name: Build MuJoCo MPC
7070
working-directory: build
71-
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test agent_utilities_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test unscented_test cubic_test gradient_planner_test gradient_test linear_test zero_test backward_pass_test ilqg_test robust_planner_test sampling_planner_test state_test task_test utilities_test ${{ matrix.additional_targets }}
71+
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test agent_utilities_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test spline_test unscented_test cubic_test gradient_planner_test gradient_test linear_test zero_test backward_pass_test ilqg_test robust_planner_test sampling_planner_test state_test task_test utilities_test ${{ matrix.additional_targets }}
7272
- name: Test MuJoCo MPC
7373
working-directory: build
7474
run: ctest -C Release --output-on-failure .

CMakeLists.txt

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,12 @@ findorfetch(
7777
EXCLUDE_FROM_ALL
7878
)
7979

80+
# TODO(nimrod): Update to the latest version of abseil, or use the one defined
81+
# by MuJoCo, once grpc fix their build issues.
8082
set(MUJOCO_DEP_VERSION_abseil
81-
c8a2f92586fe9b4e1aff049108f5db8064924d8e # LTS 20230125.1
83+
fb3621f4f897824c0dbe0615fa94543df6192f30 # LTS 20230802.1
8284
CACHE STRING "Version of `abseil` to be fetched."
83-
)
84-
85-
set(MUJOCO_DEP_VERSION_glfw3
86-
7482de6071d21db77a7236155da44c172a7f6c9e # 3.3.8
87-
CACHE STRING "Version of `glfw` to be fetched."
88-
)
89-
90-
set(MJPC_DEP_VERSION_lodepng
91-
b4ed2cd7ecf61d29076169b49199371456d4f90b
92-
CACHE STRING "Version of `lodepng` to be fetched."
85+
FORCE
9386
)
9487

9588
set(BUILD_SHARED_LIBS_OLD ${BUILD_SHARED_LIBS})
@@ -118,6 +111,9 @@ findorfetch(
118111

119112
set(ABSL_PROPAGATE_CXX_STD ON)
120113
set(ABSL_BUILD_TESTING OFF)
114+
# ABSL_ENABLE_INSTALL is needed for
115+
# https://github.com/protocolbuffers/protobuf/issues/12185#issuecomment-1594685860
116+
set(ABSL_ENABLE_INSTALL ON)
121117
findorfetch(
122118
USE_SYSTEM_PACKAGE
123119
OFF

mjpc/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ add_library(
3838
tasks/tasks.h
3939
tasks/acrobot/acrobot.cc
4040
tasks/acrobot/acrobot.h
41+
tasks/bimanual/bimanual.cc
42+
tasks/bimanual/bimanual.h
4143
tasks/cartpole/cartpole.cc
4244
tasks/cartpole/cartpole.h
4345
tasks/cube/solve.cc
@@ -122,6 +124,8 @@ add_library(
122124
direct/trajectory.h
123125
direct/model_parameters.cc
124126
direct/model_parameters.h
127+
spline/spline.cc
128+
spline/spline.h
125129
app.cc
126130
app.h
127131
norm.cc
@@ -138,8 +142,11 @@ target_compile_definitions(libmjpc PRIVATE MJSIMULATE_STATIC)
138142
target_link_libraries(
139143
libmjpc
140144
absl::any_invocable
145+
absl::check
141146
absl::flat_hash_map
147+
absl::log
142148
absl::random_random
149+
absl::span
143150
glfw
144151
lodepng
145152
mujoco::mujoco

mjpc/grpc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ findorfetch(
3030
GIT_REPO
3131
https://github.com/grpc/grpc
3232
GIT_TAG
33-
v1.53.0
33+
v1.60.1
3434
TARGETS
3535
gRPC
3636
)

mjpc/grpc/agent.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ service Agent {
3939
returns (GetTaskParametersResponse);
4040
// Set cost weights.
4141
rpc SetCostWeights(SetCostWeightsRequest) returns (SetCostWeightsResponse);
42+
// Get cost term residuals.
43+
rpc GetResiduals(GetResidualsRequest) returns (GetResidualsResponse);
4244
// Get cost term values.
4345
rpc GetCostValuesAndWeights(GetCostValuesAndWeightsRequest)
4446
returns (GetCostValuesAndWeightsResponse);
@@ -113,6 +115,16 @@ message GetActionResponse {
113115
repeated float action = 1 [packed = true];
114116
}
115117

118+
message GetResidualsRequest {}
119+
120+
message Residual {
121+
repeated double values = 1;
122+
}
123+
124+
message GetResidualsResponse {
125+
map<string, Residual> values = 1;
126+
}
127+
116128
message GetCostValuesAndWeightsRequest {}
117129

118130
message ValueAndWeight {

mjpc/grpc/agent_service.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ using ::agent::GetAllModesRequest;
3535
using ::agent::GetAllModesResponse;
3636
using ::agent::GetBestTrajectoryRequest;
3737
using ::agent::GetBestTrajectoryResponse;
38+
using ::agent::GetResidualsRequest;
39+
using ::agent::GetResidualsResponse;
3840
using ::agent::GetCostValuesAndWeightsRequest;
3941
using ::agent::GetCostValuesAndWeightsResponse;
4042
using ::agent::GetModeRequest;
@@ -181,6 +183,16 @@ grpc::Status AgentService::GetAction(grpc::ServerContext* context,
181183
request, &agent_, model, rollout_data_.get(), &rollout_state_, response);
182184
}
183185

186+
grpc::Status AgentService::GetResiduals(
187+
grpc::ServerContext* context, const GetResidualsRequest* request,
188+
GetResidualsResponse* response) {
189+
if (!Initialized()) {
190+
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
191+
}
192+
return grpc_agent_util::GetResiduals(request, &agent_, model,
193+
data_, response);
194+
}
195+
184196
grpc::Status AgentService::GetCostValuesAndWeights(
185197
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
186198
GetCostValuesAndWeightsResponse* response) {

mjpc/grpc/agent_service.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class AgentService final : public agent::Agent::Service {
5858
const agent::GetActionRequest* request,
5959
agent::GetActionResponse* response) override;
6060

61+
grpc::Status GetResiduals(
62+
grpc::ServerContext* context,
63+
const agent::GetResidualsRequest* request,
64+
agent::GetResidualsResponse* response) override;
65+
6166
grpc::Status GetCostValuesAndWeights(
6267
grpc::ServerContext* context,
6368
const agent::GetCostValuesAndWeightsRequest* request,

mjpc/grpc/agent_service_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,4 +360,16 @@ TEST_F(AgentServiceTest, GetAllModes_Works) {
360360
EXPECT_EQ(response.mode_names()[0], "default_mode");
361361
}
362362

363+
TEST_F(AgentServiceTest, GetResiduals_Works) {
364+
RunAndCheckInit("Cartpole", nullptr);
365+
366+
grpc::ClientContext context;
367+
368+
agent::GetResidualsRequest request;
369+
agent::GetResidualsResponse response;
370+
grpc::Status status = stub->GetResiduals(&context, request, &response);
371+
372+
EXPECT_TRUE(status.ok());
373+
}
374+
363375
} // namespace mjpc::agent_grpc

mjpc/grpc/grpc_agent_util.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ using ::agent::GetActionRequest;
4444
using ::agent::GetActionResponse;
4545
using ::agent::GetAllModesRequest;
4646
using ::agent::GetAllModesResponse;
47+
using ::agent::GetResidualsRequest;
48+
using ::agent::GetResidualsResponse;
4749
using ::agent::GetCostValuesAndWeightsRequest;
4850
using ::agent::GetCostValuesAndWeightsResponse;
4951
using ::agent::GetModeRequest;
@@ -58,6 +60,7 @@ using ::agent::SetCostWeightsRequest;
5860
using ::agent::SetModeRequest;
5961
using ::agent::SetStateRequest;
6062
using ::agent::SetTaskParametersRequest;
63+
using ::agent::Residual;
6164
using ::agent::ValueAndWeight;
6265

6366
grpc::Status GetState(const mjModel* model, const mjData* data,
@@ -226,6 +229,34 @@ grpc::Status GetAction(const GetActionRequest* request,
226229
return grpc::Status::OK;
227230
}
228231

232+
grpc::Status GetResiduals(
233+
const GetResidualsRequest* request, const mjpc::Agent* agent,
234+
const mjModel* model, mjData* data,
235+
GetResidualsResponse* response) {
236+
const mjModel* agent_model = agent->GetModel();
237+
const mjpc::Task* task = agent->ActiveTask();
238+
std::vector<double> residuals(task->num_residual, 0); // scratch space
239+
task->Residual(model, data, residuals.data());
240+
std::vector<int> dim_norm_residual = task->dim_norm_residual;
241+
242+
int residual_shift = 0;
243+
for (int i = 0; i < task->num_term; i++) {
244+
CHECK_EQ(agent_model->sensor_type[i], mjSENS_USER);
245+
std::string_view sensor_name(agent_model->names +
246+
agent_model->name_sensoradr[i]);
247+
248+
std::vector<double> sensor_residual_values(
249+
residuals.begin() + residual_shift,
250+
residuals.begin() + residual_shift + dim_norm_residual[i]);
251+
Residual sensor_residual;
252+
sensor_residual.mutable_values()->Assign(sensor_residual_values.begin(),
253+
sensor_residual_values.end());
254+
(*response->mutable_values())[sensor_name] = sensor_residual;
255+
residual_shift += dim_norm_residual[i];
256+
}
257+
return grpc::Status::OK;
258+
}
259+
229260
grpc::Status GetCostValuesAndWeights(
230261
const GetCostValuesAndWeightsRequest* request, const mjpc::Agent* agent,
231262
const mjModel* model, mjData* data,

mjpc/grpc/grpc_agent_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ grpc::Status GetAction(const agent::GetActionRequest* request,
3434
const mjModel* model, mjData* rollout_data,
3535
mjpc::State* rollout_state,
3636
agent::GetActionResponse* response);
37+
grpc::Status GetResiduals(
38+
const agent::GetResidualsRequest* request,
39+
const mjpc::Agent* agent, const mjModel* model, mjData* data,
40+
agent::GetResidualsResponse* response);
3741
grpc::Status GetCostValuesAndWeights(
3842
const agent::GetCostValuesAndWeightsRequest* request,
3943
const mjpc::Agent* agent, const mjModel* model, mjData* data,

0 commit comments

Comments
 (0)