
Commit 41505e1

elite average -> trajectory[0]
1 parent 1ecabed commit 41505e1

File tree

  mjpc/planners/cross_entropy/planner.cc
  mjpc/planners/cross_entropy/planner.h

2 files changed: 29 additions & 72 deletions

mjpc/planners/cross_entropy/planner.cc

Lines changed: 26 additions & 71 deletions
@@ -106,11 +106,6 @@ void CrossEntropyPlanner::Allocate() {
     trajectory[i].Allocate(kMaxTrajectoryHorizon);
     candidate_policy[i].Allocate(model, *task, kMaxTrajectoryHorizon);
   }
-
-  // elite average trajectory
-  elite_avg.Initialize(num_state, model->nu, task->num_residual,
-                       task->num_trace, kMaxTrajectoryHorizon);
-  elite_avg.Allocate(kMaxTrajectoryHorizon);
 }
 
 // reset memory to zeros
@@ -143,7 +138,6 @@ void CrossEntropyPlanner::Reset(int horizon,
     trajectory[i].Reset(kMaxTrajectoryHorizon);
     candidate_policy[i].Reset(horizon);
   }
-  elite_avg.Reset(kMaxTrajectoryHorizon);
 
   for (const auto& d : data_) {
     mju_zero(d->ctrl, model->nu);
@@ -161,11 +155,6 @@ void CrossEntropyPlanner::SetState(const State& state) {
 
 // optimize nominal policy using random sampling
 void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
-  // check horizon
-  if (horizon != elite_avg.horizon) {
-    NominalTrajectory(horizon, pool);
-  }
-
   // if num_trajectory_ has changed, use it in this new iteration.
   // num_trajectory_ might change while this function runs. Keep it constant
   // for the duration of this function.
@@ -202,12 +191,13 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
 
   // sort so that the first ncandidates elements are the best candidates, and
   // the rest are in an unspecified order
-  std::partial_sort(
-      trajectory_order.begin(), trajectory_order.begin() + num_trajectory,
-      trajectory_order.begin() + num_trajectory,
-      [&trajectory = trajectory](int a, int b) {
-        return trajectory[a].total_return < trajectory[b].total_return;
-      });
+  std::partial_sort(trajectory_order.begin() + 1,
+                    trajectory_order.begin() + 1 + num_trajectory,
+                    trajectory_order.begin() + 1 + num_trajectory,
+                    [&trajectory = trajectory](int a, int b) {
+                      return trajectory[a].total_return <
+                             trajectory[b].total_return;
+                    });
 
   // stop timer
   rollouts_compute_time = GetDuration(rollouts_start);
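
The new call leaves element 0 of trajectory_order in place and partially sorts the remaining indices by total_return, so the rollouts themselves are never moved. Below is a standalone sketch of the underlying index-ordering idiom with std::partial_sort; it uses plain std::vector and made-up values, not the planner's members.

// standalone sketch (not part of the diff): order candidate indices by return
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // hypothetical total returns (costs) of four rollouts
  std::vector<double> total_return = {3.0, 1.5, 2.0, 0.5};

  // indices 0..3; the indices are reordered, not the rollouts themselves
  std::vector<int> order(total_return.size());
  std::iota(order.begin(), order.end(), 0);

  // place the n_elite lowest-return indices first, in ascending order
  int n_elite = 2;
  std::partial_sort(order.begin(), order.begin() + n_elite, order.end(),
                    [&](int a, int b) { return total_return[a] < total_return[b]; });

  std::printf("%d %d\n", order[0], order[1]);  // prints: 3 1
  return 0;
}
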
@@ -220,66 +210,29 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
   int num_spline_points = resampled_policy.num_spline_points;
   int num_parameters = resampled_policy.num_parameters;
 
-  // reset parameters scratch to zero
-  std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0);
-
-  // reset elite average
-  elite_avg.Reset(horizon);
+  // averaged return over elites
+  double avg_return = 0.0;
 
-  // set elite average trajectory times
-  for (int tt = 0; tt <= horizon; tt++) {
-    elite_avg.times[tt] = time + tt * model->opt.timestep;
-  }
-
-  // best elite
-  int idx = trajectory_order[0];
-
-  // add parameters
-  mju_copy(parameters_scratch.data(), candidate_policy[idx].parameters.data(),
-           num_parameters);
+  // reset parameters scratch
+  std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0);
 
-  // copy first elite trajectory
-  mju_copy(elite_avg.actions.data(), trajectory[idx].actions.data(),
-           model->nu * (horizon - 1));
-  mju_copy(elite_avg.trace.data(), trajectory[idx].trace.data(),
-           trajectory[idx].dim_trace * horizon);
-  mju_copy(elite_avg.residual.data(), trajectory[idx].residual.data(),
-           elite_avg.dim_residual * horizon);
-  mju_copy(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
-  elite_avg.total_return = trajectory[idx].total_return;
-
-  // loop over remaining elites to compute average
-  for (int i = 1; i < n_elite; i++) {
+  // loop over elites to compute average
+  for (int i = 0; i < n_elite; i++) {
     // ordered trajectory index
     int idx = trajectory_order[i];
 
     // add parameters
     mju_addTo(parameters_scratch.data(),
              candidate_policy[idx].parameters.data(), num_parameters);
 
-    // add elite trajectory
-    mju_addTo(elite_avg.actions.data(), trajectory[idx].actions.data(),
-              model->nu * (horizon - 1));
-    mju_addTo(elite_avg.trace.data(), trajectory[idx].trace.data(),
-              trajectory[idx].dim_trace * horizon);
-    mju_addTo(elite_avg.residual.data(), trajectory[idx].residual.data(),
-              elite_avg.dim_residual * horizon);
-    mju_addTo(elite_avg.costs.data(), trajectory[idx].costs.data(), horizon);
-    elite_avg.total_return += trajectory[idx].total_return;
+    // add total return
+    avg_return += trajectory[idx].total_return;
   }
 
   // normalize
   mju_scl(parameters_scratch.data(), parameters_scratch.data(), 1.0 / n_elite,
           num_parameters);
-  mju_scl(elite_avg.actions.data(), elite_avg.actions.data(), 1.0 / n_elite,
-          model->nu * (horizon - 1));
-  mju_scl(elite_avg.trace.data(), elite_avg.trace.data(), 1.0 / n_elite,
-          elite_avg.dim_trace * horizon);
-  mju_scl(elite_avg.residual.data(), elite_avg.residual.data(), 1.0 / n_elite,
-          elite_avg.dim_residual * horizon);
-  mju_scl(elite_avg.costs.data(), elite_avg.costs.data(), 1.0 / n_elite,
-          horizon);
-  elite_avg.total_return /= n_elite;
+  avg_return /= n_elite;
 
   // loop over elites to compute variance
   std::fill(variance.begin(), variance.end(), 0.0);  // reset variance to zero
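
With the elite_avg trajectory gone, the only per-elite quantities accumulated above are the candidate policy's spline parameters and the scalar total_return. Below is a standalone sketch of that elite-averaging step using plain std::vector; the names, sizes, and values are illustrative, not the planner's actual data layout.

// standalone sketch (not part of the diff): average elite parameters and returns
#include <cstdio>
#include <vector>

int main() {
  // hypothetical spline parameters for three elite candidates (2 params each)
  std::vector<std::vector<double>> elite_params = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
  std::vector<double> elite_return = {0.5, 0.7, 0.9};
  int n_elite = static_cast<int>(elite_params.size());

  // accumulate parameters and returns over the elites
  std::vector<double> parameters_scratch(2, 0.0);
  double avg_return = 0.0;
  for (int i = 0; i < n_elite; i++) {
    for (size_t j = 0; j < parameters_scratch.size(); j++) {
      parameters_scratch[j] += elite_params[i][j];
    }
    avg_return += elite_return[i];
  }

  // normalize: element-wise mean of the parameters, mean of the returns
  for (double& p : parameters_scratch) p /= n_elite;
  avg_return /= n_elite;

  std::printf("%g %g %g\n", parameters_scratch[0], parameters_scratch[1], avg_return);  // prints: 3 4 0.7
  return 0;
}
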
@@ -304,9 +257,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
   }
 
   // improvement: compare nominal to elite average
-  improvement = mju_max(
-      elite_avg.total_return - trajectory[trajectory_order[0]].total_return,
-      0.0);
+  improvement =
+      mju_max(avg_return - trajectory[trajectory_order[0]].total_return, 0.0);
 
   // stop timer
   policy_update_compute_time = GetDuration(policy_update_start);
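
As a made-up worked example of the improvement metric: if avg_return over the elites is 12.5 and trajectory[trajectory_order[0]].total_return is 10.0, then improvement = mju_max(12.5 - 10.0, 0.0) = 2.5; the mju_max clamp simply keeps the reported improvement non-negative.
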
@@ -321,8 +273,9 @@ void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
   };
 
   // rollout nominal policy
-  elite_avg.Rollout(nominal_policy, task, model, data_[0].get(), state.data(),
-                    time, mocap.data(), userdata.data(), horizon);
+  trajectory[0].Rollout(nominal_policy, task, model, data_[0].get(),
+                        state.data(), time, mocap.data(), userdata.data(),
+                        horizon);
 }
 
 // set action from policy
@@ -428,7 +381,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
             s.resampled_policy.representation;
 
         // sample noise
-        s.AddNoiseToPolicy(i, std_min);
+        if (i > 0) s.AddNoiseToPolicy(i, std_min);
       }
 
       // ----- rollout sample policy ----- //
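
Skipping the noise for i == 0 means the first candidate keeps the unperturbed resampled policy, so trajectory[0] serves as the nominal rollout that BestTrajectory() returns in the last hunk. A standalone sketch of this reserve-slot-zero pattern follows; the AddNoise helper and values are placeholders, not the planner's AddNoiseToPolicy.

// standalone sketch (not part of the diff): keep slot 0 noise-free
#include <cstdio>
#include <vector>

// placeholder perturbation; stands in for noise added to a candidate policy
void AddNoise(std::vector<double>& params) {
  for (double& p : params) p += 0.1;
}

int main() {
  int num_trajectory = 3;
  std::vector<std::vector<double>> candidate(num_trajectory,
                                             std::vector<double>{1.0, 2.0});
  for (int i = 0; i < num_trajectory; i++) {
    if (i > 0) AddNoise(candidate[i]);  // index 0 stays the nominal policy
  }
  std::printf("%g %g\n", candidate[0][0], candidate[0][1]);  // prints: 1 2
  return 0;
}
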
@@ -450,8 +403,10 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
   pool.ResetCount();
 }
 
-// returns the nominal trajectory (this is the purple trace)
-const Trajectory* CrossEntropyPlanner::BestTrajectory() { return &elite_avg; }
+// returns the **nominal** trajectory (this is the purple trace)
+const Trajectory* CrossEntropyPlanner::BestTrajectory() {
+  return &trajectory[0];
+}
 
 // visualize planner-specific traces
 void CrossEntropyPlanner::Traces(mjvScene* scn) {

mjpc/planners/cross_entropy/planner.h

Lines changed: 3 additions & 1 deletion
@@ -111,7 +111,6 @@ class CrossEntropyPlanner : public Planner {
 
   // trajectories
   Trajectory trajectory[kMaxTrajectory];
-  Trajectory elite_avg;
 
   // order of indices of rolled out trajectories, ordered by total return
   std::vector<int> trajectory_order;
@@ -139,6 +138,9 @@
 
   int num_trajectory_;
   mutable std::shared_mutex mtx_;
+
+  // elite average index
+  const int elite_average_index_ = 0;
 };
 
 }  // namespace mjpc
