Skip to content

Commit 8809c54

Browse files
committed
elite_avg -> nominal
1 parent 4a06147 commit 8809c54

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

mjpc/planners/cross_entropy/planner.cc

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ void CrossEntropyPlanner::Allocate() {
106106
trajectory[i].Allocate(kMaxTrajectoryHorizon);
107107
candidate_policy[i].Allocate(model, *task, kMaxTrajectoryHorizon);
108108
}
109+
nominal_trajectory.Initialize(num_state, model->nu, task->num_residual,
110+
task->num_trace, kMaxTrajectoryHorizon);
111+
nominal_trajectory.Allocate(kMaxTrajectoryHorizon);
109112
}
110113

111114
// reset memory to zeros
@@ -138,6 +141,7 @@ void CrossEntropyPlanner::Reset(int horizon,
138141
trajectory[i].Reset(kMaxTrajectoryHorizon);
139142
candidate_policy[i].Reset(horizon);
140143
}
144+
nominal_trajectory.Reset(kMaxTrajectoryHorizon);
141145

142146
for (const auto& d : data_) {
143147
mju_zero(d->ctrl, model->nu);
@@ -161,8 +165,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
161165
int num_trajectory = num_trajectory_;
162166

163167
// n_elite_ might change in the GUI - keep constant for in this function
164-
n_elite_ = std::min(n_elite_, num_trajectory - 1);
165-
int n_elite = std::min(n_elite_, num_trajectory - 1);
168+
n_elite_ = std::min(n_elite_, num_trajectory);
169+
int n_elite = std::min(n_elite_, num_trajectory);
166170

167171
// resize number of mjData
168172
ResizeMjData(model, pool.NumThreads());
@@ -191,13 +195,12 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
191195

192196
// sort so that the first ncandidates elements are the best candidates, and
193197
// the rest are in an unspecified order
194-
std::partial_sort(trajectory_order.begin() + 1,
195-
trajectory_order.begin() + 1 + num_trajectory,
196-
trajectory_order.begin() + 1 + num_trajectory,
197-
[&trajectory = trajectory](int a, int b) {
198-
return trajectory[a].total_return <
199-
trajectory[b].total_return;
200-
});
198+
std::partial_sort(
199+
trajectory_order.begin(), trajectory_order.begin() + num_trajectory,
200+
trajectory_order.begin() + num_trajectory,
201+
[&trajectory = trajectory](int a, int b) {
202+
return trajectory[a].total_return < trajectory[b].total_return;
203+
});
201204

202205
// stop timer
203206
rollouts_compute_time = GetDuration(rollouts_start);
@@ -265,17 +268,20 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
265268
}
266269

267270
// compute trajectory using nominal policy
268-
void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
271+
void CrossEntropyPlanner::NominalTrajectory(int horizon) {
269272
// set policy
270273
auto nominal_policy = [&cp = resampled_policy](
271274
double* action, const double* state, double time) {
272275
cp.Action(action, state, time);
273276
};
274277

275278
// rollout nominal policy
276-
trajectory[elite_average_index_].Rollout(
277-
nominal_policy, task, model, data_[0].get(), state.data(), time,
278-
mocap.data(), userdata.data(), horizon);
279+
nominal_trajectory.Rollout(nominal_policy, task, model,
280+
data_[ThreadPool::WorkerId()].get(), state.data(),
281+
time, mocap.data(), userdata.data(), horizon);
282+
}
283+
void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
284+
NominalTrajectory(horizon);
279285
}
280286

281287
// set action from policy
@@ -316,6 +322,8 @@ void CrossEntropyPlanner::ResamplePolicy(int horizon) {
316322

317323
LinearRange(resampled_policy.times.data(), time_shift,
318324
resampled_policy.times[0], num_spline_points);
325+
326+
resampled_policy.representation = policy.representation;
319327
}
320328

321329
// add random noise to nominal policy
@@ -381,7 +389,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
381389
s.resampled_policy.representation;
382390

383391
// sample noise
384-
if (i != s.elite_average_index_) s.AddNoiseToPolicy(i, std_min);
392+
s.AddNoiseToPolicy(i, std_min);
385393
}
386394

387395
// ----- rollout sample policy ----- //
@@ -399,13 +407,17 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
399407
state.data(), time, mocap.data(), userdata.data(), horizon);
400408
});
401409
}
402-
pool.WaitCount(count_before + num_trajectory);
410+
// nominal
411+
pool.Schedule([&s = *this, horizon]() { s.NominalTrajectory(horizon); });
412+
413+
// wait
414+
pool.WaitCount(count_before + num_trajectory + 1);
403415
pool.ResetCount();
404416
}
405417

406418
// returns the **nominal** trajectory (this is the purple trace)
407419
const Trajectory* CrossEntropyPlanner::BestTrajectory() {
408-
return &trajectory[elite_average_index_];
420+
return &nominal_trajectory;
409421
}
410422

411423
// visualize planner-specific traces

mjpc/planners/cross_entropy/planner.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class CrossEntropyPlanner : public Planner {
5757

5858
// compute trajectory using nominal policy
5959
void NominalTrajectory(int horizon, ThreadPool& pool) override;
60+
void NominalTrajectory(int horizon);
6061

6162
// set action from policy
6263
void ActionFromPolicy(double* action, const double* state, double time,
@@ -111,6 +112,7 @@ class CrossEntropyPlanner : public Planner {
111112

112113
// trajectories
113114
Trajectory trajectory[kMaxTrajectory];
115+
Trajectory nominal_trajectory;
114116

115117
// order of indices of rolled out trajectories, ordered by total return
116118
std::vector<int> trajectory_order;
@@ -128,19 +130,13 @@ class CrossEntropyPlanner : public Planner {
128130
// improvement
129131
double improvement;
130132

131-
// flags
132-
int processed_noise_status;
133-
134133
// timing
135134
std::atomic<double> noise_compute_time;
136135
double rollouts_compute_time;
137136
double policy_update_compute_time;
138137

139138
int num_trajectory_;
140139
mutable std::shared_mutex mtx_;
141-
142-
// elite average index
143-
const int elite_average_index_ = 0;
144140
};
145141

146142
} // namespace mjpc

0 commit comments

Comments
 (0)