Skip to content

Commit 30df6f9

Browse files
authored
Merge branch 'main' into sample_gradient_v2
2 parents f22e80c + 1ecabed commit 30df6f9

29 files changed

+3149
-24
lines changed

.github/workflows/build.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@ jobs:
1919
-DCMAKE_C_COMPILER:STRING=clang-12
2020
-DCMAKE_CXX_COMPILER:STRING=clang++-12
2121
-DMJPC_BUILD_GRPC_SERVICE:BOOL=ON
22-
additional_targets: "agent_server direct_server filter_server"
2322
tmpdir: "/tmp"
2423
- os: macos-12
2524
cmake_args: >-
2625
-G Ninja
2726
-DMJPC_BUILD_GRPC_SERVICE:BOOL=ON
28-
additional_targets: "agent_server direct_server filter_server"
2927
tmpdir: "/tmp"
3028

3129
name: "MuJoCo MPC on ${{ matrix.os }} ${{ matrix.additional_label }}"
@@ -43,9 +41,10 @@ jobs:
4341
libxrandr-dev
4442
libxi-dev
4543
ninja-build
44+
zlib1g-dev
4645
- name: Prepare macOS
4746
if: ${{ runner.os == 'macOS' }}
48-
run: brew install ninja
47+
run: brew install ninja zlib
4948
- name: Prepare Windows
5049
if: ${{ runner.os == 'Windows' }}
5150
# Install llvm 16 manually, remove after
@@ -68,10 +67,12 @@ jobs:
6867
$cmake_extra_args
6968
- name: Build MuJoCo MPC
7069
working-directory: build
71-
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test agent_utilities_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test unscented_test cubic_test gradient_planner_test gradient_test linear_test zero_test backward_pass_test ilqg_test robust_planner_test sampling_planner_test state_test task_test utilities_test ${{ matrix.additional_targets }}
70+
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }}
7271
- name: Test MuJoCo MPC
7372
working-directory: build
74-
run: ctest -C Release --output-on-failure .
73+
run: >
74+
cd mjpc/test &&
75+
ctest -C Release --output-on-failure .
7576
- name: Notify team chat
7677
shell: bash
7778
env:

CMakeLists.txt

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,12 @@ findorfetch(
7777
EXCLUDE_FROM_ALL
7878
)
7979

80+
# TODO(nimrod): Update to the latest version of abseil, or use the one defined
81+
# by MuJoCo, once grpc fix their build issues.
8082
set(MUJOCO_DEP_VERSION_abseil
81-
c8a2f92586fe9b4e1aff049108f5db8064924d8e # LTS 20230125.1
83+
fb3621f4f897824c0dbe0615fa94543df6192f30 # LTS 20230802.1
8284
CACHE STRING "Version of `abseil` to be fetched."
83-
)
84-
85-
set(MUJOCO_DEP_VERSION_glfw3
86-
7482de6071d21db77a7236155da44c172a7f6c9e # 3.3.8
87-
CACHE STRING "Version of `glfw` to be fetched."
88-
)
89-
90-
set(MJPC_DEP_VERSION_lodepng
91-
b4ed2cd7ecf61d29076169b49199371456d4f90b
92-
CACHE STRING "Version of `lodepng` to be fetched."
85+
FORCE
9386
)
9487

9588
set(BUILD_SHARED_LIBS_OLD ${BUILD_SHARED_LIBS})
@@ -118,6 +111,9 @@ findorfetch(
118111

119112
set(ABSL_PROPAGATE_CXX_STD ON)
120113
set(ABSL_BUILD_TESTING OFF)
114+
# ABSL_ENABLE_INSTALL is needed for
115+
# https://github.com/protocolbuffers/protobuf/issues/12185#issuecomment-1594685860
116+
set(ABSL_ENABLE_INSTALL ON)
121117
findorfetch(
122118
USE_SYSTEM_PACKAGE
123119
OFF

mjpc/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ add_library(
3838
tasks/tasks.h
3939
tasks/acrobot/acrobot.cc
4040
tasks/acrobot/acrobot.h
41+
tasks/bimanual/bimanual.cc
42+
tasks/bimanual/bimanual.h
4143
tasks/cartpole/cartpole.cc
4244
tasks/cartpole/cartpole.h
4345
tasks/cube/solve.cc
@@ -122,6 +124,8 @@ add_library(
122124
direct/trajectory.h
123125
direct/model_parameters.cc
124126
direct/model_parameters.h
127+
spline/spline.cc
128+
spline/spline.h
125129
app.cc
126130
app.h
127131
norm.cc
@@ -138,8 +142,11 @@ target_compile_definitions(libmjpc PRIVATE MJSIMULATE_STATIC)
138142
target_link_libraries(
139143
libmjpc
140144
absl::any_invocable
145+
absl::check
141146
absl::flat_hash_map
147+
absl::log
142148
absl::random_random
149+
absl::span
143150
glfw
144151
lodepng
145152
mujoco::mujoco

mjpc/grpc/CMakeLists.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ set(BUILD_SHARED_LIBS
2020
CACHE INTERNAL "Build SHARED libraries"
2121
)
2222

23+
find_package(ZLIB REQUIRED)
24+
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
25+
set(ZLIB_BUILD_EXAMPLES OFF)
26+
2327
findorfetch(
2428
USE_SYSTEM_PACKAGE
2529
OFF
@@ -30,14 +34,11 @@ findorfetch(
3034
GIT_REPO
3135
https://github.com/grpc/grpc
3236
GIT_TAG
33-
v1.53.0
37+
v1.60.1
3438
TARGETS
3539
gRPC
3640
)
3741

38-
find_package(ZLIB REQUIRED)
39-
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
40-
set(ZLIB_BUILD_EXAMPLES OFF)
4142
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
4243
set(_REFLECTION grpc++_reflection)
4344
set(_PROTOBUF_PROTOC $<TARGET_FILE:protoc>)

mjpc/grpc/agent.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ service Agent {
3939
returns (GetTaskParametersResponse);
4040
// Set cost weights.
4141
rpc SetCostWeights(SetCostWeightsRequest) returns (SetCostWeightsResponse);
42+
// Get cost term residuals.
43+
rpc GetResiduals(GetResidualsRequest) returns (GetResidualsResponse);
4244
// Get cost term values.
4345
rpc GetCostValuesAndWeights(GetCostValuesAndWeightsRequest)
4446
returns (GetCostValuesAndWeightsResponse);
@@ -113,6 +115,16 @@ message GetActionResponse {
113115
repeated float action = 1 [packed = true];
114116
}
115117

118+
message GetResidualsRequest {}
119+
120+
message Residual {
121+
repeated double values = 1;
122+
}
123+
124+
message GetResidualsResponse {
125+
map<string, Residual> values = 1;
126+
}
127+
116128
message GetCostValuesAndWeightsRequest {}
117129

118130
message ValueAndWeight {

mjpc/grpc/agent_service.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ using ::agent::GetAllModesRequest;
3535
using ::agent::GetAllModesResponse;
3636
using ::agent::GetBestTrajectoryRequest;
3737
using ::agent::GetBestTrajectoryResponse;
38+
using ::agent::GetResidualsRequest;
39+
using ::agent::GetResidualsResponse;
3840
using ::agent::GetCostValuesAndWeightsRequest;
3941
using ::agent::GetCostValuesAndWeightsResponse;
4042
using ::agent::GetModeRequest;
@@ -181,6 +183,16 @@ grpc::Status AgentService::GetAction(grpc::ServerContext* context,
181183
request, &agent_, model, rollout_data_.get(), &rollout_state_, response);
182184
}
183185

186+
grpc::Status AgentService::GetResiduals(
187+
grpc::ServerContext* context, const GetResidualsRequest* request,
188+
GetResidualsResponse* response) {
189+
if (!Initialized()) {
190+
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
191+
}
192+
return grpc_agent_util::GetResiduals(request, &agent_, model,
193+
data_, response);
194+
}
195+
184196
grpc::Status AgentService::GetCostValuesAndWeights(
185197
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
186198
GetCostValuesAndWeightsResponse* response) {

mjpc/grpc/agent_service.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class AgentService final : public agent::Agent::Service {
5858
const agent::GetActionRequest* request,
5959
agent::GetActionResponse* response) override;
6060

61+
grpc::Status GetResiduals(
62+
grpc::ServerContext* context,
63+
const agent::GetResidualsRequest* request,
64+
agent::GetResidualsResponse* response) override;
65+
6166
grpc::Status GetCostValuesAndWeights(
6267
grpc::ServerContext* context,
6368
const agent::GetCostValuesAndWeightsRequest* request,

mjpc/grpc/agent_service_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,4 +360,16 @@ TEST_F(AgentServiceTest, GetAllModes_Works) {
360360
EXPECT_EQ(response.mode_names()[0], "default_mode");
361361
}
362362

363+
TEST_F(AgentServiceTest, GetResiduals_Works) {
364+
RunAndCheckInit("Cartpole", nullptr);
365+
366+
grpc::ClientContext context;
367+
368+
agent::GetResidualsRequest request;
369+
agent::GetResidualsResponse response;
370+
grpc::Status status = stub->GetResiduals(&context, request, &response);
371+
372+
EXPECT_TRUE(status.ok());
373+
}
374+
363375
} // namespace mjpc::agent_grpc

mjpc/grpc/grpc_agent_util.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ using ::agent::GetActionRequest;
4444
using ::agent::GetActionResponse;
4545
using ::agent::GetAllModesRequest;
4646
using ::agent::GetAllModesResponse;
47+
using ::agent::GetResidualsRequest;
48+
using ::agent::GetResidualsResponse;
4749
using ::agent::GetCostValuesAndWeightsRequest;
4850
using ::agent::GetCostValuesAndWeightsResponse;
4951
using ::agent::GetModeRequest;
@@ -58,6 +60,7 @@ using ::agent::SetCostWeightsRequest;
5860
using ::agent::SetModeRequest;
5961
using ::agent::SetStateRequest;
6062
using ::agent::SetTaskParametersRequest;
63+
using ::agent::Residual;
6164
using ::agent::ValueAndWeight;
6265

6366
grpc::Status GetState(const mjModel* model, const mjData* data,
@@ -226,6 +229,34 @@ grpc::Status GetAction(const GetActionRequest* request,
226229
return grpc::Status::OK;
227230
}
228231

232+
grpc::Status GetResiduals(
233+
const GetResidualsRequest* request, const mjpc::Agent* agent,
234+
const mjModel* model, mjData* data,
235+
GetResidualsResponse* response) {
236+
const mjModel* agent_model = agent->GetModel();
237+
const mjpc::Task* task = agent->ActiveTask();
238+
std::vector<double> residuals(task->num_residual, 0); // scratch space
239+
task->Residual(model, data, residuals.data());
240+
std::vector<int> dim_norm_residual = task->dim_norm_residual;
241+
242+
int residual_shift = 0;
243+
for (int i = 0; i < task->num_term; i++) {
244+
CHECK_EQ(agent_model->sensor_type[i], mjSENS_USER);
245+
std::string_view sensor_name(agent_model->names +
246+
agent_model->name_sensoradr[i]);
247+
248+
std::vector<double> sensor_residual_values(
249+
residuals.begin() + residual_shift,
250+
residuals.begin() + residual_shift + dim_norm_residual[i]);
251+
Residual sensor_residual;
252+
sensor_residual.mutable_values()->Assign(sensor_residual_values.begin(),
253+
sensor_residual_values.end());
254+
(*response->mutable_values())[sensor_name] = sensor_residual;
255+
residual_shift += dim_norm_residual[i];
256+
}
257+
return grpc::Status::OK;
258+
}
259+
229260
grpc::Status GetCostValuesAndWeights(
230261
const GetCostValuesAndWeightsRequest* request, const mjpc::Agent* agent,
231262
const mjModel* model, mjData* data,

mjpc/grpc/grpc_agent_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ grpc::Status GetAction(const agent::GetActionRequest* request,
3434
const mjModel* model, mjData* rollout_data,
3535
mjpc::State* rollout_state,
3636
agent::GetActionResponse* response);
37+
grpc::Status GetResiduals(
38+
const agent::GetResidualsRequest* request,
39+
const mjpc::Agent* agent, const mjModel* model, mjData* data,
40+
agent::GetResidualsResponse* response);
3741
grpc::Status GetCostValuesAndWeights(
3842
const agent::GetCostValuesAndWeightsRequest* request,
3943
const mjpc::Agent* agent, const mjModel* model, mjData* data,

0 commit comments

Comments
 (0)