Skip to content

Commit 75a7a86

Browse files
authored
Merge branch 'main' into ps_py
2 parents 26091ef + 0f1c648 commit 75a7a86

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1687
-99
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ real-time predictive control with [MuJoCo](https://mujoco.org/), developed by
1616
Google DeepMind.
1717

1818
MJPC allows the user to easily author and solve complex robotics tasks, and
19-
currently supports three shooting-based planners: derivative-based iLQG and
20-
Gradient Descent, and a simple yet very competitive derivative-free method
21-
called Predictive Sampling.
19+
currently supports multiple shooting-based planners. Derivative-based methods include iLQG and
20+
Gradient Descent, while derivative-free methods include a simple yet very competitive planner
21+
called Predictive Sampling and the Cross Entropy Method (with diagonal covariance).
2222

2323
- [Overview](#overview)
2424
- [Graphical User Interface](#graphical-user-interface)

cmake/MujocoLinkOptions.cmake

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function(get_mujoco_extra_link_options OUTPUT_VAR)
2323
set(EXTRA_LINK_OPTIONS)
2424

2525
if(WIN32)
26-
set(CMAKE_REQUIRED_FLAGS "-fuse-ld=lld-link")
26+
set(CMAKE_REQUIRED_LINK_OPTIONS "-fuse-ld=lld-link")
2727
check_c_source_compiles("int main() {}" SUPPORTS_LLD)
2828
if(SUPPORTS_LLD)
2929
set(EXTRA_LINK_OPTIONS
@@ -34,24 +34,24 @@ function(get_mujoco_extra_link_options OUTPUT_VAR)
3434
)
3535
endif()
3636
else()
37-
set(CMAKE_REQUIRED_FLAGS "-fuse-ld=lld")
37+
set(CMAKE_REQUIRED_LINK_OPTIONS "-fuse-ld=lld")
3838
check_c_source_compiles("int main() {}" SUPPORTS_LLD)
3939
if(SUPPORTS_LLD)
4040
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -fuse-ld=lld)
4141
else()
42-
set(CMAKE_REQUIRED_FLAGS "-fuse-ld=gold")
42+
set(CMAKE_REQUIRED_LINK_OPTIONS "-fuse-ld=gold")
4343
check_c_source_compiles("int main() {}" SUPPORTS_GOLD)
4444
if(SUPPORTS_GOLD)
4545
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -fuse-ld=gold)
4646
endif()
4747
endif()
4848

49-
set(CMAKE_REQUIRED_FLAGS ${EXTRA_LINK_OPTIONS} "-Wl,--gc-sections")
49+
set(CMAKE_REQUIRED_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} "-Wl,--gc-sections")
5050
check_c_source_compiles("int main() {}" SUPPORTS_GC_SECTIONS)
5151
if(SUPPORTS_GC_SECTIONS)
5252
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -Wl,--gc-sections)
5353
else()
54-
set(CMAKE_REQUIRED_FLAGS ${EXTRA_LINK_OPTIONS} "-Wl,-dead_strip")
54+
set(CMAKE_REQUIRED_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} "-Wl,-dead_strip")
5555
check_c_source_compiles("int main() {}" SUPPORTS_DEAD_STRIP)
5656
if(SUPPORTS_DEAD_STRIP)
5757
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -Wl,-dead_strip)

docs/OVERVIEW.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ Values in brackets should be replaced by the designer.
132132

133133
`Agent` settings can be specified by prepending `agent_` for corresponding class members.
134134

135-
`Planner` settings can similarly be specified by prepending the corresponding optimizer name, (e.g., `sampling_`, `gradient_`, `ilqg_`).
135+
`Planner` settings can similarly be specified by prepending the corresponding optimizer name, (e.g., `sampling_`, `cross_entropy_`, `gradient_`, `ilqg_`).
136136

137137
It is also possible to create GUI elements for parameters that are passed to the residual function. These are specified by the prefix `residual_`, when the suffix will be the display name of the slider:
138138

@@ -280,12 +280,15 @@ Additionally, custom labeled buttons can be added to the GUI by specifying a str
280280

281281
The purpose of `Planner` is to find improved policies using numerical optimization.
282282

283-
This library includes three planners that use different techniques to perform this search:
283+
This library includes multiple planners that use different techniques to perform this search:
284284

285285
- **Predictive Sampling**
286286
- random search
287287
- derivative free
288288
- spline representation for controls
289+
- **Cross Entropy Method**
290+
- all properties of Predictive Sampling
291+
- refits a nominal policy to mean of elite samples instead of using the best
289292
- **Gradient Descent**
290293
- requires gradients
291294
- spline representation for controls

mjpc/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ add_library(
4040
tasks/acrobot/acrobot.h
4141
tasks/cartpole/cartpole.cc
4242
tasks/cartpole/cartpole.h
43+
tasks/cube/solve.cc
44+
tasks/cube/solve.h
4345
tasks/fingers/fingers.cc
4446
tasks/fingers/fingers.h
4547
tasks/hand/hand.cc
@@ -54,6 +56,8 @@ add_library(
5456
tasks/manipulation/common.h
5557
tasks/manipulation/manipulation.cc
5658
tasks/manipulation/manipulation.h
59+
tasks/op3/stand.cc
60+
tasks/op3/stand.h
5761
tasks/panda/panda.cc
5862
tasks/panda/panda.h
5963
tasks/particle/particle.cc
@@ -75,6 +79,8 @@ add_library(
7579
planners/cost_derivatives.h
7680
planners/model_derivatives.cc
7781
planners/model_derivatives.h
82+
planners/cross_entropy/planner.cc
83+
planners/cross_entropy/planner.h
7884
planners/robust/robust_planner.cc
7985
planners/robust/robust_planner.h
8086
planners/sampling/planner.cc

mjpc/agent.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ void Agent::Initialize(const mjModel* model) {
105105
state.Initialize(model);
106106

107107
// initialize estimator
108-
if (reset_estimator) {
108+
if (reset_estimator && estimator_enabled) {
109109
for (const auto& estimator : estimators_) {
110110
estimator->Initialize(model_);
111111
estimator->Reset();
@@ -169,7 +169,7 @@ void Agent::Reset(const double* initial_repeated_action) {
169169
state.Reset();
170170

171171
// estimator
172-
if (reset_estimator) {
172+
if (reset_estimator && estimator_enabled) {
173173
for (const auto& estimator : estimators_) {
174174
estimator->Reset();
175175
}

mjpc/direct/direct.cc

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,6 @@ void Direct::Initialize(const mjModel* model) {
148148
act.Initialize(na, configuration_length_);
149149
times.Initialize(1, configuration_length_);
150150

151-
// ctrl
152-
ctrl.Initialize(model->nu, configuration_length_);
153-
154151
// prior
155152
configuration_previous.Initialize(nq, configuration_length_);
156153

@@ -348,7 +345,6 @@ void Direct::Reset(const mjData* data) {
348345
acceleration.Reset();
349346
act.Reset();
350347
times.Reset();
351-
ctrl.Reset();
352348

353349
// prior
354350
configuration_previous.Reset();
@@ -637,8 +633,6 @@ void Direct::SetConfigurationLength(int length) {
637633
act.SetLength(configuration_length_);
638634
times.SetLength(configuration_length_);
639635

640-
ctrl.SetLength(configuration_length_);
641-
642636
configuration_previous.SetLength(configuration_length_);
643637

644638
sensor_measurement.SetLength(configuration_length_);
@@ -1489,8 +1483,7 @@ void Direct::InverseDynamicsPrediction() {
14891483
auto start = std::chrono::steady_clock::now();
14901484

14911485
// dimension
1492-
int nq = model->nq, nv = model->nv, na = model->na, nu = model->nu,
1493-
ns = nsensordata_;
1486+
int nq = model->nq, nv = model->nv, na = model->na, ns = nsensordata_;
14941487

14951488
// set parameters
14961489
if (nparam_ > 0) {
@@ -1502,7 +1495,7 @@ void Direct::InverseDynamicsPrediction() {
15021495
int count_before = pool_.GetCount();
15031496

15041497
// first time step
1505-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1498+
pool_.Schedule([&batch = *this, nq, nv]() {
15061499
// time index
15071500
int t = 0;
15081501

@@ -1518,7 +1511,6 @@ void Direct::InverseDynamicsPrediction() {
15181511
mju_copy(d->qpos, q0, nq);
15191512
mju_zero(d->qvel, nv);
15201513
mju_zero(d->qacc, nv);
1521-
mju_zero(d->ctrl, nu);
15221514
d->time = batch.times.Get(t)[0];
15231515

15241516
// position sensors
@@ -1551,12 +1543,11 @@ void Direct::InverseDynamicsPrediction() {
15511543
// loop over predictions
15521544
for (int t = 1; t < configuration_length_ - 1; t++) {
15531545
// schedule
1554-
pool_.Schedule([&batch = *this, nq, nv, na, ns, nu, t]() {
1546+
pool_.Schedule([&batch = *this, nq, nv, na, ns, t]() {
15551547
// terms
15561548
double* qt = batch.configuration.Get(t);
15571549
double* vt = batch.velocity.Get(t);
15581550
double* at = batch.acceleration.Get(t);
1559-
double* ct = batch.ctrl.Get(t);
15601551

15611552
// data
15621553
mjData* d = batch.data_[t].get();
@@ -1565,7 +1556,6 @@ void Direct::InverseDynamicsPrediction() {
15651556
mju_copy(d->qpos, qt, nq);
15661557
mju_copy(d->qvel, vt, nv);
15671558
mju_copy(d->qacc, at, nv);
1568-
mju_copy(d->ctrl, ct, nu);
15691559

15701560
// inverse dynamics
15711561
mj_inverse(batch.model, d);
@@ -1585,7 +1575,7 @@ void Direct::InverseDynamicsPrediction() {
15851575
}
15861576

15871577
// last time step
1588-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1578+
pool_.Schedule([&batch = *this, nq, nv]() {
15891579
// time index
15901580
int t = batch.ConfigurationLength() - 1;
15911581

@@ -1602,7 +1592,6 @@ void Direct::InverseDynamicsPrediction() {
16021592
mju_copy(d->qpos, qT, nq);
16031593
mju_copy(d->qvel, vT, nv);
16041594
mju_zero(d->qacc, nv);
1605-
mju_zero(d->ctrl, nu);
16061595
d->time = batch.times.Get(t)[0];
16071596

16081597
// position sensors
@@ -1653,7 +1642,7 @@ void Direct::InverseDynamicsDerivatives() {
16531642
auto start = std::chrono::steady_clock::now();
16541643

16551644
// dimension
1656-
int nq = model->nq, nv = model->nv, nu = model->nu;
1645+
int nq = model->nq, nv = model->nv;
16571646

16581647
// set parameters
16591648
if (nparam_ > 0) {
@@ -1665,7 +1654,7 @@ void Direct::InverseDynamicsDerivatives() {
16651654
int count_before = pool_.GetCount();
16661655

16671656
// first time step
1668-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1657+
pool_.Schedule([&batch = *this, nq, nv]() {
16691658
// time index
16701659
int t = 0;
16711660

@@ -1680,7 +1669,6 @@ void Direct::InverseDynamicsDerivatives() {
16801669
mju_copy(d->qpos, q0, nq);
16811670
mju_zero(d->qvel, nv);
16821671
mju_zero(d->qacc, nv);
1683-
mju_zero(d->ctrl, nu);
16841672
d->time = batch.times.Get(t)[0];
16851673

16861674
// finite-difference derivatives
@@ -1725,12 +1713,11 @@ void Direct::InverseDynamicsDerivatives() {
17251713
// loop over predictions
17261714
for (int t = 1; t < configuration_length_ - 1; t++) {
17271715
// schedule
1728-
pool_.Schedule([&batch = *this, nq, nv, nu, t]() {
1716+
pool_.Schedule([&batch = *this, nq, nv, t]() {
17291717
// unpack
17301718
double* q = batch.configuration.Get(t);
17311719
double* v = batch.velocity.Get(t);
17321720
double* a = batch.acceleration.Get(t);
1733-
double* c = batch.ctrl.Get(t);
17341721

17351722
double* dsdq = batch.block_sensor_configuration_.Get(t);
17361723
double* dsdv = batch.block_sensor_velocity_.Get(t);
@@ -1743,11 +1730,10 @@ void Direct::InverseDynamicsDerivatives() {
17431730
double* dadf = batch.block_force_acceleration_.Get(t);
17441731
mjData* data = batch.data_[t].get(); // TODO(taylor): WorkerID
17451732

1746-
// set (state, acceleration) + ctrl
1733+
// set state, acceleration
17471734
mju_copy(data->qpos, q, nq);
17481735
mju_copy(data->qvel, v, nv);
17491736
mju_copy(data->qacc, a, nv);
1750-
mju_copy(data->ctrl, c, nu);
17511737

17521738
// finite-difference derivatives
17531739
mjd_inverseFD(batch.model, data, batch.finite_difference.tolerance,
@@ -1767,7 +1753,7 @@ void Direct::InverseDynamicsDerivatives() {
17671753
}
17681754

17691755
// last time step
1770-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1756+
pool_.Schedule([&batch = *this, nq, nv]() {
17711757
// time index
17721758
int t = batch.ConfigurationLength() - 1;
17731759

@@ -1784,7 +1770,6 @@ void Direct::InverseDynamicsDerivatives() {
17841770
mju_copy(d->qpos, qT, nq);
17851771
mju_copy(d->qvel, vT, nv);
17861772
mju_zero(d->qacc, nv);
1787-
mju_zero(d->ctrl, nu);
17881773
d->time = batch.times.Get(t)[0];
17891774

17901775
// finite-difference derivatives
@@ -2061,7 +2046,7 @@ double Direct::Cost(double* gradient, double* hessian) {
20612046
// set dense rows in band Hessian
20622047
if (hessian) {
20632048
mju_copy(hessian + nvel_ * nband_, dense_parameter_.data(),
2064-
nparam_ * ntotal_);
2049+
nparam_ * ntotal_);
20652050
}
20662051
}
20672052

mjpc/direct/direct.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ class Direct {
161161
DirectTrajectory<double> acceleration; // nv x T
162162
DirectTrajectory<double> act; // na x T
163163
DirectTrajectory<double> times; // 1 x T
164-
DirectTrajectory<double> ctrl; // nu x T
165164
DirectTrajectory<double> sensor_measurement; // ns x T
166165
DirectTrajectory<double> sensor_prediction; // ns x T
167166
DirectTrajectory<int> sensor_mask; // num_sensor x T

mjpc/estimators/batch.cc

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,6 @@ void Batch::Initialize(const mjModel* model) {
119119
act_cache_.Initialize(na, max_history_);
120120
times_cache_.Initialize(1, max_history_);
121121

122-
// ctrl
123-
ctrl_cache_.Initialize(model->nu, max_history_);
124-
125122
// prior
126123
configuration_previous_cache_.Initialize(nq, max_history_);
127124

@@ -241,7 +238,6 @@ void Batch::Reset(const mjData* data) {
241238
acceleration_cache_.Reset();
242239
act_cache_.Reset();
243240
times_cache_.Reset();
244-
ctrl_cache_.Reset();
245241

246242
// prior
247243
configuration_previous_cache_.Reset();
@@ -325,9 +321,6 @@ void Batch::Update(const double* ctrl, const double* sensor) {
325321
// set next time
326322
times.Set(&d->time, t + 1);
327323

328-
// set ctrl
329-
this->ctrl.Set(ctrl, t);
330-
331324
// set sensor
332325
sensor_measurement.Set(sensor + sensor_start_index_, t);
333326

@@ -527,7 +520,6 @@ void Batch::Shift(int shift) {
527520
acceleration.Shift(shift);
528521
act.Shift(shift);
529522
times.Shift(shift);
530-
ctrl.Shift(shift);
531523

532524
configuration_previous.Shift(shift);
533525

@@ -815,7 +807,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
815807
acceleration_cache_.Reset();
816808
act_cache_.Reset();
817809
times_cache_.Reset();
818-
ctrl_cache_.Reset();
819810
sensor_measurement_cache_.Reset();
820811
sensor_prediction_cache_.Reset();
821812
sensor_mask_cache_.Reset();
@@ -831,7 +822,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
831822
acceleration_cache_.SetLength(length);
832823
act_cache_.SetLength(length);
833824
times_cache_.SetLength(length);
834-
ctrl_cache_.SetLength(length);
835825
sensor_measurement_cache_.SetLength(length);
836826
sensor_prediction_cache_.SetLength(length);
837827
sensor_mask_cache_.SetLength(length);
@@ -846,7 +836,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
846836
acceleration_cache_.Set(acceleration.Get(i), i);
847837
act_cache_.Set(act.Get(i), i);
848838
times_cache_.Set(times.Get(i), i);
849-
ctrl_cache_.Set(ctrl.Get(i), i);
850839
sensor_measurement_cache_.Set(sensor_measurement.Get(i), i);
851840
sensor_prediction_cache_.Set(sensor_prediction.Get(i), i);
852841
sensor_mask_cache_.Set(sensor_mask.Get(i), i);
@@ -867,7 +856,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
867856
acceleration.Set(acceleration_cache_.Get(new_head + i), i);
868857
act.Set(act_cache_.Get(new_head + i), i);
869858
times.Set(times_cache_.Get(new_head + i), i);
870-
ctrl.Set(ctrl_cache_.Get(new_head + i), i);
871859
sensor_measurement.Set(sensor_measurement_cache_.Get(new_head + i), i);
872860
sensor_prediction.Set(sensor_prediction_cache_.Get(new_head + i), i);
873861
sensor_mask.Set(sensor_mask_cache_.Get(new_head + i), i);

mjpc/estimators/batch.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
namespace mjpc {
3333

3434
// max filter history
35-
inline constexpr int kMaxFilterHistory = 128;
35+
inline constexpr int kMaxFilterHistory = 64;
3636

3737
// ----- batch estimator ----- //
3838
// based on: "Physically-Consistent Sensor Fusion in Contact-Rich Behaviors"
@@ -239,7 +239,6 @@ class Batch : public Direct, public Estimator {
239239
DirectTrajectory<double> acceleration_cache_; // nv x T
240240
DirectTrajectory<double> act_cache_; // na x T
241241
DirectTrajectory<double> times_cache_; // 1 x T
242-
DirectTrajectory<double> ctrl_cache_; // nu x T
243242
DirectTrajectory<double> sensor_measurement_cache_; // ns x T
244243
DirectTrajectory<double> sensor_prediction_cache_; // ns x T
245244
DirectTrajectory<int> sensor_mask_cache_; // num_sensor x T

mjpc/grpc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ findorfetch(
3737

3838
find_package(ZLIB REQUIRED)
3939
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
40-
40+
set(ZLIB_BUILD_EXAMPLES OFF)
4141
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
4242
set(_REFLECTION grpc++_reflection)
4343
set(_PROTOBUF_PROTOC $<TARGET_FILE:protoc>)

0 commit comments

Comments
 (0)