Skip to content

Commit 14adceb

Browse files
authored
Merge pull request #265 from thowell/ce
Cross entropy planner: Fix trace dimension
2 parents a8bdcc1 + 6333379 commit 14adceb

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

mjpc/planners/cross_entropy/planner.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
241241
// copy first elite trajectory
242242
mju_copy(elite_avg.actions.data(), trajectory[idx].actions.data(),
243243
model->nu * (horizon - 1));
244-
mju_copy(elite_avg.trace.data(), trajectory[idx].trace.data(), 3 * horizon);
244+
mju_copy(elite_avg.trace.data(), trajectory[idx].trace.data(),
245+
trajectory[idx].dim_trace * horizon);
245246
mju_copy(elite_avg.residual.data(), trajectory[idx].residual.data(),
246247
elite_avg.dim_residual * horizon);
247248
mju_copy(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
@@ -260,7 +261,7 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
260261
mju_addTo(elite_avg.actions.data(), trajectory[idx].actions.data(),
261262
model->nu * (horizon - 1));
262263
mju_addTo(elite_avg.trace.data(), trajectory[idx].trace.data(),
263-
3 * horizon);
264+
trajectory[idx].dim_trace * horizon);
264265
mju_addTo(elite_avg.residual.data(), trajectory[idx].residual.data(),
265266
elite_avg.dim_residual * horizon);
266267
mju_addTo(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
@@ -273,7 +274,7 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
273274
mju_scl(elite_avg.actions.data(), elite_avg.actions.data(), 1.0 / n_elite,
274275
model->nu * (horizon - 1));
275276
mju_scl(elite_avg.trace.data(), elite_avg.trace.data(), 1.0 / n_elite,
276-
3 * horizon);
277+
elite_avg.dim_trace * horizon);
277278
mju_scl(elite_avg.residual.data(), elite_avg.residual.data(), 1.0 / n_elite,
278279
elite_avg.dim_residual * horizon);
279280
mju_scl(elite_avg.costs.data(), elite_avg.costs.data(), 1.0 / n_elite,

0 commit comments

Comments
 (0)