
Commit bd6eda6

Merge pull request #284 from thowell/ce
Cross entropy planner changes
2 parents 4384ef0 + ed9e635

File tree

2 files changed (+35, -70 lines)

mjpc/planners/cross_entropy/planner.cc

Lines changed: 33 additions & 66 deletions
@@ -106,11 +106,9 @@ void CrossEntropyPlanner::Allocate() {
     trajectory[i].Allocate(kMaxTrajectoryHorizon);
     candidate_policy[i].Allocate(model, *task, kMaxTrajectoryHorizon);
   }
-
-  // elite average trajectory
-  elite_avg.Initialize(num_state, model->nu, task->num_residual,
-                       task->num_trace, kMaxTrajectoryHorizon);
-  elite_avg.Allocate(kMaxTrajectoryHorizon);
+  nominal_trajectory.Initialize(num_state, model->nu, task->num_residual,
+                                task->num_trace, kMaxTrajectoryHorizon);
+  nominal_trajectory.Allocate(kMaxTrajectoryHorizon);
 }
 
 // reset memory to zeros
@@ -143,7 +141,7 @@ void CrossEntropyPlanner::Reset(int horizon,
     trajectory[i].Reset(kMaxTrajectoryHorizon);
     candidate_policy[i].Reset(horizon);
   }
-  elite_avg.Reset(kMaxTrajectoryHorizon);
+  nominal_trajectory.Reset(kMaxTrajectoryHorizon);
 
   for (const auto& d : data_) {
     mju_zero(d->ctrl, model->nu);
@@ -161,11 +159,6 @@ void CrossEntropyPlanner::SetState(const State& state) {
 
 // optimize nominal policy using random sampling
 void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
-  // check horizon
-  if (horizon != elite_avg.horizon) {
-    NominalTrajectory(horizon, pool);
-  }
-
   // if num_trajectory_ has changed, use it in this new iteration.
   // num_trajectory_ might change while this function runs. Keep it constant
   // for the duration of this function.
@@ -220,66 +213,29 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
   int num_spline_points = resampled_policy.num_spline_points;
   int num_parameters = resampled_policy.num_parameters;
 
-  // reset parameters scratch to zero
-  std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0);
-
-  // reset elite average
-  elite_avg.Reset(horizon);
-
-  // set elite average trajectory times
-  for (int tt = 0; tt <= horizon; tt++) {
-    elite_avg.times[tt] = time + tt * model->opt.timestep;
-  }
-
-  // best elite
-  int idx = trajectory_order[0];
+  // averaged return over elites
+  double avg_return = 0.0;
 
-  // add parameters
-  mju_copy(parameters_scratch.data(), candidate_policy[idx].parameters.data(),
-           num_parameters);
+  // reset parameters scratch
+  std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0);
 
-  // copy first elite trajectory
-  mju_copy(elite_avg.actions.data(), trajectory[idx].actions.data(),
-           model->nu * (horizon - 1));
-  mju_copy(elite_avg.trace.data(), trajectory[idx].trace.data(),
-           trajectory[idx].dim_trace * horizon);
-  mju_copy(elite_avg.residual.data(), trajectory[idx].residual.data(),
-           elite_avg.dim_residual * horizon);
-  mju_copy(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
-  elite_avg.total_return = trajectory[idx].total_return;
-
-  // loop over remaining elites to compute average
-  for (int i = 1; i < n_elite; i++) {
+  // loop over elites to compute average
+  for (int i = 0; i < n_elite; i++) {
     // ordered trajectory index
     int idx = trajectory_order[i];
 
     // add parameters
     mju_addTo(parameters_scratch.data(),
               candidate_policy[idx].parameters.data(), num_parameters);
 
-    // add elite trajectory
-    mju_addTo(elite_avg.actions.data(), trajectory[idx].actions.data(),
-              model->nu * (horizon - 1));
-    mju_addTo(elite_avg.trace.data(), trajectory[idx].trace.data(),
-              trajectory[idx].dim_trace * horizon);
-    mju_addTo(elite_avg.residual.data(), trajectory[idx].residual.data(),
-              elite_avg.dim_residual * horizon);
-    mju_addTo(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
-    elite_avg.total_return += trajectory[idx].total_return;
+    // add total return
+    avg_return += trajectory[idx].total_return;
   }
 
   // normalize
   mju_scl(parameters_scratch.data(), parameters_scratch.data(), 1.0 / n_elite,
           num_parameters);
-  mju_scl(elite_avg.actions.data(), elite_avg.actions.data(), 1.0 / n_elite,
-          model->nu * (horizon - 1));
-  mju_scl(elite_avg.trace.data(), elite_avg.trace.data(), 1.0 / n_elite,
-          elite_avg.dim_trace * horizon);
-  mju_scl(elite_avg.residual.data(), elite_avg.residual.data(), 1.0 / n_elite,
-          elite_avg.dim_residual * horizon);
-  mju_scl(elite_avg.costs.data(), elite_avg.costs.data(), 1.0 / n_elite,
-          horizon);
-  elite_avg.total_return /= n_elite;
+  avg_return /= n_elite;
 
   // loop over elites to compute variance
   std::fill(variance.begin(), variance.end(), 0.0);  // reset variance to zero
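Note: the elite update in this hunk no longer accumulates full elite trajectories; it only averages the elites' spline parameters and keeps a scalar average of their total returns. A minimal standalone sketch of that step, using plain std::vector storage and hypothetical names (AverageElites, candidate_params, total_returns) in place of mjpc's Trajectory and mju_* helpers:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the elite-averaging step: average the spline
// parameters of the n_elite best candidates and their total returns.
void AverageElites(const std::vector<std::vector<double>>& candidate_params,
                   const std::vector<double>& total_returns,
                   const std::vector<int>& trajectory_order, int n_elite,
                   std::vector<double>* parameters_scratch,
                   double* avg_return) {
  std::fill(parameters_scratch->begin(), parameters_scratch->end(), 0.0);
  *avg_return = 0.0;
  for (int i = 0; i < n_elite; i++) {
    int idx = trajectory_order[i];  // candidates ordered by total return
    const std::vector<double>& p = candidate_params[idx];
    for (std::size_t j = 0; j < parameters_scratch->size(); j++) {
      (*parameters_scratch)[j] += p[j];  // accumulate elite parameters
    }
    *avg_return += total_returns[idx];  // accumulate elite returns
  }
  // normalize by the number of elites
  for (double& v : *parameters_scratch) v /= n_elite;
  *avg_return /= n_elite;
}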
@@ -304,25 +260,28 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
   }
 
   // improvement: compare nominal to elite average
-  improvement = mju_max(
-      elite_avg.total_return - trajectory[trajectory_order[0]].total_return,
-      0.0);
+  improvement =
+      mju_max(avg_return - trajectory[trajectory_order[0]].total_return, 0.0);
 
   // stop timer
   policy_update_compute_time = GetDuration(policy_update_start);
 }
 
 // compute trajectory using nominal policy
-void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
+void CrossEntropyPlanner::NominalTrajectory(int horizon) {
   // set policy
   auto nominal_policy = [&cp = resampled_policy](
                             double* action, const double* state, double time) {
     cp.Action(action, state, time);
   };
 
   // rollout nominal policy
-  elite_avg.Rollout(nominal_policy, task, model, data_[0].get(), state.data(),
-                    time, mocap.data(), userdata.data(), horizon);
+  nominal_trajectory.Rollout(nominal_policy, task, model,
+                             data_[ThreadPool::WorkerId()].get(), state.data(),
+                             time, mocap.data(), userdata.data(), horizon);
+}
+void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
+  NominalTrajectory(horizon);
 }
 
 // set action from policy
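Note: with the elite-average trajectory gone, the improvement statistic reduces to a scalar comparison between the average elite return and the best trajectory's return, clamped at zero (total_return accumulates cost, so a positive gap means the best rollout beat the elite average). A sketch of the same computation, with std::max standing in for mju_max and hypothetical argument names:

#include <algorithm>

// Sketch of the improvement statistic: how far the best sampled return
// (lowest cost) falls below the elite-average return, clamped at zero.
double Improvement(double avg_elite_return, double best_return) {
  return std::max(avg_elite_return - best_return, 0.0);
}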
@@ -363,6 +322,8 @@ void CrossEntropyPlanner::ResamplePolicy(int horizon) {
 
   LinearRange(resampled_policy.times.data(), time_shift,
               resampled_policy.times[0], num_spline_points);
+
+  resampled_policy.representation = policy.representation;
 }
 
 // add random noise to nominal policy
@@ -446,12 +407,18 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
           state.data(), time, mocap.data(), userdata.data(), horizon);
     });
   }
-  pool.WaitCount(count_before + num_trajectory);
+  // nominal
+  pool.Schedule([&s = *this, horizon]() { s.NominalTrajectory(horizon); });
+
+  // wait
+  pool.WaitCount(count_before + num_trajectory + 1);
   pool.ResetCount();
 }
 
-// returns the nominal trajectory (this is the purple trace)
-const Trajectory* CrossEntropyPlanner::BestTrajectory() { return &elite_avg; }
+// returns the **nominal** trajectory (this is the purple trace)
+const Trajectory* CrossEntropyPlanner::BestTrajectory() {
+  return &nominal_trajectory;
+}
 
 // visualize planner-specific traces
 void CrossEntropyPlanner::Traces(mjvScene* scn) {
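Note: the rollout phase above now enqueues the nominal rollout as one extra job on the same pool, so the wait count grows to count_before + num_trajectory + 1. A self-contained sketch of that scheduling pattern; Pool, GetCount, Schedule, WaitCount, and ResetCount are stand-ins modeled on mjpc's ThreadPool usage, not a definitive reproduction of its API:

#include <functional>

// Sketch: enqueue N candidate rollouts plus one nominal rollout, then block
// until all N + 1 scheduled jobs have finished.
template <typename Pool>
void ScheduleRollouts(Pool& pool, int num_trajectory,
                      const std::function<void(int)>& rollout,
                      const std::function<void()>& nominal) {
  int count_before = pool.GetCount();
  for (int i = 0; i < num_trajectory; i++) {
    pool.Schedule([&rollout, i]() { rollout(i); });  // candidate rollout i
  }
  pool.Schedule([&nominal]() { nominal(); });          // nominal rollout
  pool.WaitCount(count_before + num_trajectory + 1);   // wait for all jobs
  pool.ResetCount();
}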

mjpc/planners/cross_entropy/planner.h

Lines changed: 2 additions & 4 deletions
@@ -57,6 +57,7 @@ class CrossEntropyPlanner : public Planner {
 
   // compute trajectory using nominal policy
   void NominalTrajectory(int horizon, ThreadPool& pool) override;
+  void NominalTrajectory(int horizon);
 
   // set action from policy
   void ActionFromPolicy(double* action, const double* state, double time,
@@ -111,7 +112,7 @@ class CrossEntropyPlanner : public Planner {
 
   // trajectories
   Trajectory trajectory[kMaxTrajectory];
-  Trajectory elite_avg;
+  Trajectory nominal_trajectory;
 
   // order of indices of rolled out trajectories, ordered by total return
   std::vector<int> trajectory_order;
@@ -129,9 +130,6 @@ class CrossEntropyPlanner : public Planner {
   // improvement
   double improvement;
 
-  // flags
-  int processed_noise_status;
-
   // timing
   std::atomic<double> noise_compute_time;
   double rollouts_compute_time;
