Skip to content

Commit 7c421b8

Browse files
Merge of 97ab026
PiperOrigin-RevId: 601710890 Change-Id: If7fbe394ca443f015317a2af69e2033b9280e1b8
2 parents 4b8f1b4 + 97ab026 commit 7c421b8

File tree

15 files changed

+713
-90
lines changed

15 files changed

+713
-90
lines changed

mjpc/direct/direct.cc

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,6 @@ void Direct::Initialize(const mjModel* model) {
148148
act.Initialize(na, configuration_length_);
149149
times.Initialize(1, configuration_length_);
150150

151-
// ctrl
152-
ctrl.Initialize(model->nu, configuration_length_);
153-
154151
// prior
155152
configuration_previous.Initialize(nq, configuration_length_);
156153

@@ -348,7 +345,6 @@ void Direct::Reset(const mjData* data) {
348345
acceleration.Reset();
349346
act.Reset();
350347
times.Reset();
351-
ctrl.Reset();
352348

353349
// prior
354350
configuration_previous.Reset();
@@ -637,8 +633,6 @@ void Direct::SetConfigurationLength(int length) {
637633
act.SetLength(configuration_length_);
638634
times.SetLength(configuration_length_);
639635

640-
ctrl.SetLength(configuration_length_);
641-
642636
configuration_previous.SetLength(configuration_length_);
643637

644638
sensor_measurement.SetLength(configuration_length_);
@@ -1489,8 +1483,7 @@ void Direct::InverseDynamicsPrediction() {
14891483
auto start = std::chrono::steady_clock::now();
14901484

14911485
// dimension
1492-
int nq = model->nq, nv = model->nv, na = model->na, nu = model->nu,
1493-
ns = nsensordata_;
1486+
int nq = model->nq, nv = model->nv, na = model->na, ns = nsensordata_;
14941487

14951488
// set parameters
14961489
if (nparam_ > 0) {
@@ -1502,7 +1495,7 @@ void Direct::InverseDynamicsPrediction() {
15021495
int count_before = pool_.GetCount();
15031496

15041497
// first time step
1505-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1498+
pool_.Schedule([&batch = *this, nq, nv]() {
15061499
// time index
15071500
int t = 0;
15081501

@@ -1518,7 +1511,6 @@ void Direct::InverseDynamicsPrediction() {
15181511
mju_copy(d->qpos, q0, nq);
15191512
mju_zero(d->qvel, nv);
15201513
mju_zero(d->qacc, nv);
1521-
mju_zero(d->ctrl, nu);
15221514
d->time = batch.times.Get(t)[0];
15231515

15241516
// position sensors
@@ -1551,12 +1543,11 @@ void Direct::InverseDynamicsPrediction() {
15511543
// loop over predictions
15521544
for (int t = 1; t < configuration_length_ - 1; t++) {
15531545
// schedule
1554-
pool_.Schedule([&batch = *this, nq, nv, na, ns, nu, t]() {
1546+
pool_.Schedule([&batch = *this, nq, nv, na, ns, t]() {
15551547
// terms
15561548
double* qt = batch.configuration.Get(t);
15571549
double* vt = batch.velocity.Get(t);
15581550
double* at = batch.acceleration.Get(t);
1559-
double* ct = batch.ctrl.Get(t);
15601551

15611552
// data
15621553
mjData* d = batch.data_[t].get();
@@ -1565,7 +1556,6 @@ void Direct::InverseDynamicsPrediction() {
15651556
mju_copy(d->qpos, qt, nq);
15661557
mju_copy(d->qvel, vt, nv);
15671558
mju_copy(d->qacc, at, nv);
1568-
mju_copy(d->ctrl, ct, nu);
15691559

15701560
// inverse dynamics
15711561
mj_inverse(batch.model, d);
@@ -1585,7 +1575,7 @@ void Direct::InverseDynamicsPrediction() {
15851575
}
15861576

15871577
// last time step
1588-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1578+
pool_.Schedule([&batch = *this, nq, nv]() {
15891579
// time index
15901580
int t = batch.ConfigurationLength() - 1;
15911581

@@ -1602,7 +1592,6 @@ void Direct::InverseDynamicsPrediction() {
16021592
mju_copy(d->qpos, qT, nq);
16031593
mju_copy(d->qvel, vT, nv);
16041594
mju_zero(d->qacc, nv);
1605-
mju_zero(d->ctrl, nu);
16061595
d->time = batch.times.Get(t)[0];
16071596

16081597
// position sensors
@@ -1653,7 +1642,7 @@ void Direct::InverseDynamicsDerivatives() {
16531642
auto start = std::chrono::steady_clock::now();
16541643

16551644
// dimension
1656-
int nq = model->nq, nv = model->nv, nu = model->nu;
1645+
int nq = model->nq, nv = model->nv;
16571646

16581647
// set parameters
16591648
if (nparam_ > 0) {
@@ -1665,7 +1654,7 @@ void Direct::InverseDynamicsDerivatives() {
16651654
int count_before = pool_.GetCount();
16661655

16671656
// first time step
1668-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1657+
pool_.Schedule([&batch = *this, nq, nv]() {
16691658
// time index
16701659
int t = 0;
16711660

@@ -1680,7 +1669,6 @@ void Direct::InverseDynamicsDerivatives() {
16801669
mju_copy(d->qpos, q0, nq);
16811670
mju_zero(d->qvel, nv);
16821671
mju_zero(d->qacc, nv);
1683-
mju_zero(d->ctrl, nu);
16841672
d->time = batch.times.Get(t)[0];
16851673

16861674
// finite-difference derivatives
@@ -1725,12 +1713,11 @@ void Direct::InverseDynamicsDerivatives() {
17251713
// loop over predictions
17261714
for (int t = 1; t < configuration_length_ - 1; t++) {
17271715
// schedule
1728-
pool_.Schedule([&batch = *this, nq, nv, nu, t]() {
1716+
pool_.Schedule([&batch = *this, nq, nv, t]() {
17291717
// unpack
17301718
double* q = batch.configuration.Get(t);
17311719
double* v = batch.velocity.Get(t);
17321720
double* a = batch.acceleration.Get(t);
1733-
double* c = batch.ctrl.Get(t);
17341721

17351722
double* dsdq = batch.block_sensor_configuration_.Get(t);
17361723
double* dsdv = batch.block_sensor_velocity_.Get(t);
@@ -1743,11 +1730,10 @@ void Direct::InverseDynamicsDerivatives() {
17431730
double* dadf = batch.block_force_acceleration_.Get(t);
17441731
mjData* data = batch.data_[t].get(); // TODO(taylor): WorkerID
17451732

1746-
// set (state, acceleration) + ctrl
1733+
// set state, acceleration
17471734
mju_copy(data->qpos, q, nq);
17481735
mju_copy(data->qvel, v, nv);
17491736
mju_copy(data->qacc, a, nv);
1750-
mju_copy(data->ctrl, c, nu);
17511737

17521738
// finite-difference derivatives
17531739
mjd_inverseFD(batch.model, data, batch.finite_difference.tolerance,
@@ -1767,7 +1753,7 @@ void Direct::InverseDynamicsDerivatives() {
17671753
}
17681754

17691755
// last time step
1770-
pool_.Schedule([&batch = *this, nq, nv, nu]() {
1756+
pool_.Schedule([&batch = *this, nq, nv]() {
17711757
// time index
17721758
int t = batch.ConfigurationLength() - 1;
17731759

@@ -1784,7 +1770,6 @@ void Direct::InverseDynamicsDerivatives() {
17841770
mju_copy(d->qpos, qT, nq);
17851771
mju_copy(d->qvel, vT, nv);
17861772
mju_zero(d->qacc, nv);
1787-
mju_zero(d->ctrl, nu);
17881773
d->time = batch.times.Get(t)[0];
17891774

17901775
// finite-difference derivatives
@@ -2061,7 +2046,7 @@ double Direct::Cost(double* gradient, double* hessian) {
20612046
// set dense rows in band Hessian
20622047
if (hessian) {
20632048
mju_copy(hessian + nvel_ * nband_, dense_parameter_.data(),
2064-
nparam_ * ntotal_);
2049+
nparam_ * ntotal_);
20652050
}
20662051
}
20672052

mjpc/direct/direct.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#define MJPC_DIRECT_OPTIMIZER_H_
1717

1818
#include <memory>
19-
#include <mutex>
2019
#include <string>
2120
#include <vector>
2221

@@ -161,7 +160,6 @@ class Direct {
161160
DirectTrajectory<double> acceleration; // nv x T
162161
DirectTrajectory<double> act; // na x T
163162
DirectTrajectory<double> times; // 1 x T
164-
DirectTrajectory<double> ctrl; // nu x T
165163
DirectTrajectory<double> sensor_measurement; // ns x T
166164
DirectTrajectory<double> sensor_prediction; // ns x T
167165
DirectTrajectory<int> sensor_mask; // num_sensor x T

mjpc/estimators/batch.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#include "mjpc/array_safety.h"
2525
#include "mjpc/estimators/estimator.h"
2626
#include "mjpc/direct/direct.h"
27-
#include "mjpc/norm.h"
2827
#include "mjpc/threadpool.h"
2928
#include "mjpc/utilities.h"
3029

@@ -119,9 +118,6 @@ void Batch::Initialize(const mjModel* model) {
119118
act_cache_.Initialize(na, max_history_);
120119
times_cache_.Initialize(1, max_history_);
121120

122-
// ctrl
123-
ctrl_cache_.Initialize(model->nu, max_history_);
124-
125121
// prior
126122
configuration_previous_cache_.Initialize(nq, max_history_);
127123

@@ -241,7 +237,6 @@ void Batch::Reset(const mjData* data) {
241237
acceleration_cache_.Reset();
242238
act_cache_.Reset();
243239
times_cache_.Reset();
244-
ctrl_cache_.Reset();
245240

246241
// prior
247242
configuration_previous_cache_.Reset();
@@ -325,9 +320,6 @@ void Batch::Update(const double* ctrl, const double* sensor) {
325320
// set next time
326321
times.Set(&d->time, t + 1);
327322

328-
// set ctrl
329-
this->ctrl.Set(ctrl, t);
330-
331323
// set sensor
332324
sensor_measurement.Set(sensor + sensor_start_index_, t);
333325

@@ -527,7 +519,6 @@ void Batch::Shift(int shift) {
527519
acceleration.Shift(shift);
528520
act.Shift(shift);
529521
times.Shift(shift);
530-
ctrl.Shift(shift);
531522

532523
configuration_previous.Shift(shift);
533524

@@ -815,7 +806,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
815806
acceleration_cache_.Reset();
816807
act_cache_.Reset();
817808
times_cache_.Reset();
818-
ctrl_cache_.Reset();
819809
sensor_measurement_cache_.Reset();
820810
sensor_prediction_cache_.Reset();
821811
sensor_mask_cache_.Reset();
@@ -831,7 +821,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
831821
acceleration_cache_.SetLength(length);
832822
act_cache_.SetLength(length);
833823
times_cache_.SetLength(length);
834-
ctrl_cache_.SetLength(length);
835824
sensor_measurement_cache_.SetLength(length);
836825
sensor_prediction_cache_.SetLength(length);
837826
sensor_mask_cache_.SetLength(length);
@@ -846,7 +835,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
846835
acceleration_cache_.Set(acceleration.Get(i), i);
847836
act_cache_.Set(act.Get(i), i);
848837
times_cache_.Set(times.Get(i), i);
849-
ctrl_cache_.Set(ctrl.Get(i), i);
850838
sensor_measurement_cache_.Set(sensor_measurement.Get(i), i);
851839
sensor_prediction_cache_.Set(sensor_prediction.Get(i), i);
852840
sensor_mask_cache_.Set(sensor_mask.Get(i), i);
@@ -867,7 +855,6 @@ void Batch::ShiftResizeTrajectory(int new_head, int new_length) {
867855
acceleration.Set(acceleration_cache_.Get(new_head + i), i);
868856
act.Set(act_cache_.Get(new_head + i), i);
869857
times.Set(times_cache_.Get(new_head + i), i);
870-
ctrl.Set(ctrl_cache_.Get(new_head + i), i);
871858
sensor_measurement.Set(sensor_measurement_cache_.Get(new_head + i), i);
872859
sensor_prediction.Set(sensor_prediction_cache_.Get(new_head + i), i);
873860
sensor_mask.Set(sensor_mask_cache_.Get(new_head + i), i);

mjpc/estimators/batch.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
namespace mjpc {
3333

3434
// max filter history
35-
inline constexpr int kMaxFilterHistory = 128;
35+
inline constexpr int kMaxFilterHistory = 64;
3636

3737
// ----- batch estimator ----- //
3838
// based on: "Physically-Consistent Sensor Fusion in Contact-Rich Behaviors"
@@ -239,7 +239,6 @@ class Batch : public Direct, public Estimator {
239239
DirectTrajectory<double> acceleration_cache_; // nv x T
240240
DirectTrajectory<double> act_cache_; // na x T
241241
DirectTrajectory<double> times_cache_; // 1 x T
242-
DirectTrajectory<double> ctrl_cache_; // nu x T
243242
DirectTrajectory<double> sensor_measurement_cache_; // ns x T
244243
DirectTrajectory<double> sensor_prediction_cache_; // ns x T
245244
DirectTrajectory<int> sensor_mask_cache_; // num_sensor x T

mjpc/grpc/direct.proto

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,14 @@ message Data {
5555
repeated double velocity = 2 [packed = true];
5656
repeated double acceleration = 3 [packed = true];
5757
repeated double time = 4 [packed = true];
58-
repeated double ctrl = 5 [packed = true];
59-
repeated double configuration_previous = 6 [packed = true];
60-
repeated double sensor_measurement = 7 [packed = true];
61-
repeated double sensor_prediction = 8 [packed = true];
62-
repeated int32 sensor_mask = 9 [packed = true];
63-
repeated double force_measurement = 10 [packed = true];
64-
repeated double force_prediction = 11 [packed = true];
65-
repeated double parameters = 12 [packed = true];
66-
repeated double parameters_previous = 13 [packed = true];
58+
repeated double configuration_previous = 5 [packed = true];
59+
repeated double sensor_measurement = 6 [packed = true];
60+
repeated double sensor_prediction = 7 [packed = true];
61+
repeated int32 sensor_mask = 8 [packed = true];
62+
repeated double force_measurement = 9 [packed = true];
63+
repeated double force_prediction = 10 [packed = true];
64+
repeated double parameters = 11 [packed = true];
65+
repeated double parameters_previous = 12 [packed = true];
6766
}
6867

6968
message DataRequest {

mjpc/grpc/direct_service.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -176,19 +176,6 @@ grpc::Status DirectService::Data(grpc::ServerContext* context,
176176
double* time = optimizer_.times.Get(index);
177177
output->add_time(time[0]);
178178

179-
// set ctrl
180-
int nu = optimizer_.model->nu;
181-
if (input.ctrl_size() > 0) {
182-
CHECK_SIZE("ctrl", nu, input.ctrl_size());
183-
optimizer_.ctrl.Set(input.ctrl().data(), index);
184-
}
185-
186-
// get ctrl
187-
double* ctrl = optimizer_.ctrl.Get(index);
188-
for (int i = 0; i < nu; i++) {
189-
output->add_ctrl(ctrl[i]);
190-
}
191-
192179
// set previous configuration
193180
if (input.configuration_previous_size() > 0) {
194181
CHECK_SIZE("configuration_previous", nq,

mjpc/test/direct/direct_sensor_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ TEST(SensorCost, Particle) {
5858
// copy configuration, measurement
5959
mju_copy(optimizer.configuration.Data(), sim.qpos.Data(), nq * T);
6060
mju_copy(optimizer.sensor_measurement.Data(), sim.sensor.Data(), ns * T);
61-
mju_copy(optimizer.ctrl.Data(), sim.ctrl.Data(), model->nu * T);
6261

6362
// corrupt configurations
6463
absl::BitGen gen_;

python/mujoco_mpc/demos/direct/particle_smoother.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import mujoco
1615
import matplotlib.pyplot as plt
1716
import mediapy as media
17+
import mujoco
18+
from mujoco_mpc import direct as direct_lib
1819
import numpy as np
1920

20-
# set current directory to mjpc/python/mujoco_mpc
21-
from mujoco_mpc import direct as direct_lib
2221
# %%
2322
# 1D Particle Model
2423
xml = """
@@ -167,7 +166,6 @@
167166
optimizer.data(
168167
t,
169168
configuration=qinit[:, t],
170-
ctrl=ctrl[:, t],
171169
sensor_measurement=noisy_sensor[:, t],
172170
force_measurement=qfrc[:, t],
173171
time=np.array([time[t]]),

python/mujoco_mpc/demos/direct/particle_trajopt.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import mujoco
16-
import numpy as np
1715
import matplotlib.pyplot as plt
1816
import mediapy as media
19-
20-
# set current directory to mjpc/python/mujoco_mpc
17+
import mujoco
2118
from mujoco_mpc import direct as direct_lib
19+
import numpy as np
20+
21+
2222
# %%
2323
# 2D Particle Model
2424
xml = """
@@ -163,7 +163,6 @@
163163
data_ = optimizer.data(
164164
t,
165165
configuration=qt,
166-
ctrl=ct,
167166
sensor_measurement=st,
168167
sensor_mask=mt,
169168
force_measurement=ft,

python/mujoco_mpc/demos/filter/particle_drop.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import mujoco
16-
import numpy as np
1715
import matplotlib.pyplot as plt
1816
import mediapy as media
19-
20-
# set current directory to mjpc/python/mujoco_mpc
17+
import mujoco
2118
from mujoco_mpc import filter as filter_lib
19+
import numpy as np
2220

2321
# %%
2422
xml = """

0 commit comments

Comments
 (0)