@@ -106,6 +106,9 @@ void CrossEntropyPlanner::Allocate() {
106106 trajectory[i].Allocate (kMaxTrajectoryHorizon );
107107 candidate_policy[i].Allocate (model, *task, kMaxTrajectoryHorizon );
108108 }
109+ nominal_trajectory.Initialize (num_state, model->nu , task->num_residual ,
110+ task->num_trace , kMaxTrajectoryHorizon );
111+ nominal_trajectory.Allocate (kMaxTrajectoryHorizon );
109112}
110113
111114// reset memory to zeros
@@ -138,6 +141,7 @@ void CrossEntropyPlanner::Reset(int horizon,
138141 trajectory[i].Reset (kMaxTrajectoryHorizon );
139142 candidate_policy[i].Reset (horizon);
140143 }
144+ nominal_trajectory.Reset (kMaxTrajectoryHorizon );
141145
142146 for (const auto & d : data_) {
143147 mju_zero (d->ctrl , model->nu );
@@ -161,8 +165,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
161165 int num_trajectory = num_trajectory_;
162166
163167 // n_elite_ might change in the GUI - keep constant for in this function
164- n_elite_ = std::min (n_elite_, num_trajectory - 1 );
165- int n_elite = std::min (n_elite_, num_trajectory - 1 );
168+ n_elite_ = std::min (n_elite_, num_trajectory);
169+ int n_elite = std::min (n_elite_, num_trajectory);
166170
167171 // resize number of mjData
168172 ResizeMjData (model, pool.NumThreads ());
@@ -191,13 +195,12 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
191195
192196 // sort so that the first ncandidates elements are the best candidates, and
193197 // the rest are in an unspecified order
194- std::partial_sort (trajectory_order.begin () + 1 ,
195- trajectory_order.begin () + 1 + num_trajectory,
196- trajectory_order.begin () + 1 + num_trajectory,
197- [&trajectory = trajectory](int a, int b) {
198- return trajectory[a].total_return <
199- trajectory[b].total_return ;
200- });
198+ std::partial_sort (
199+ trajectory_order.begin (), trajectory_order.begin () + num_trajectory,
200+ trajectory_order.begin () + num_trajectory,
201+ [&trajectory = trajectory](int a, int b) {
202+ return trajectory[a].total_return < trajectory[b].total_return ;
203+ });
201204
202205 // stop timer
203206 rollouts_compute_time = GetDuration (rollouts_start);
@@ -265,17 +268,20 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
265268}
266269
267270// compute trajectory using nominal policy
268- void CrossEntropyPlanner::NominalTrajectory (int horizon, ThreadPool& pool ) {
271+ void CrossEntropyPlanner::NominalTrajectory (int horizon) {
269272 // set policy
270273 auto nominal_policy = [&cp = resampled_policy](
271274 double * action, const double * state, double time) {
272275 cp.Action (action, state, time);
273276 };
274277
275278 // rollout nominal policy
276- trajectory[elite_average_index_].Rollout (
277- nominal_policy, task, model, data_[0 ].get (), state.data (), time,
278- mocap.data (), userdata.data (), horizon);
279+ nominal_trajectory.Rollout (nominal_policy, task, model,
280+ data_[ThreadPool::WorkerId ()].get (), state.data (),
281+ time, mocap.data (), userdata.data (), horizon);
282+ }
283+ void CrossEntropyPlanner::NominalTrajectory (int horizon, ThreadPool& pool) {
284+ NominalTrajectory (horizon);
279285}
280286
281287// set action from policy
@@ -316,6 +322,8 @@ void CrossEntropyPlanner::ResamplePolicy(int horizon) {
316322
317323 LinearRange (resampled_policy.times .data (), time_shift,
318324 resampled_policy.times [0 ], num_spline_points);
325+
326+ resampled_policy.representation = policy.representation ;
319327}
320328
321329// add random noise to nominal policy
@@ -381,7 +389,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
381389 s.resampled_policy .representation ;
382390
383391 // sample noise
384- if (i != s. elite_average_index_ ) s.AddNoiseToPolicy (i, std_min);
392+ s.AddNoiseToPolicy (i, std_min);
385393 }
386394
387395 // ----- rollout sample policy ----- //
@@ -399,13 +407,17 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
399407 state.data (), time, mocap.data (), userdata.data (), horizon);
400408 });
401409 }
402- pool.WaitCount (count_before + num_trajectory);
410+ // nominal
411+ pool.Schedule ([&s = *this , horizon]() { s.NominalTrajectory (horizon); });
412+
413+ // wait
414+ pool.WaitCount (count_before + num_trajectory + 1 );
403415 pool.ResetCount ();
404416}
405417
406418// returns the **nominal** trajectory (this is the purple trace)
407419const Trajectory* CrossEntropyPlanner::BestTrajectory () {
408- return &trajectory[elite_average_index_] ;
420+ return &nominal_trajectory ;
409421}
410422
411423// visualize planner-specific traces
0 commit comments