@@ -161,8 +161,8 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
161161 int num_trajectory = num_trajectory_;
162162
163163 // n_elite_ might change in the GUI - keep it constant within this function
164- n_elite_ = std::min (n_elite_, num_trajectory);
165- int n_elite = std::min (n_elite_, num_trajectory);
164+ n_elite_ = std::min (n_elite_, num_trajectory - 1 );
165+ int n_elite = std::min (n_elite_, num_trajectory - 1 );
166166
167167 // resize number of mjData
168168 ResizeMjData (model, pool.NumThreads ());
@@ -273,9 +273,9 @@ void CrossEntropyPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
273273 };
274274
275275 // rollout nominal policy
276- trajectory[0 ].Rollout (nominal_policy, task, model, data_[ 0 ]. get (),
277- state. data (), time, mocap. data (), userdata .data (),
278- horizon);
276+ trajectory[elite_average_index_ ].Rollout (
277+ nominal_policy, task, model, data_[ 0 ]. get (), state .data (), time ,
278+ mocap. data (), userdata. data (), horizon);
279279}
280280
281281// set action from policy
@@ -381,7 +381,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
381381 s.resampled_policy .representation ;
382382
383383 // sample noise
384- if (i > 0 ) s.AddNoiseToPolicy (i, std_min);
384+ if (i != s. elite_average_index_ ) s.AddNoiseToPolicy (i, std_min);
385385 }
386386
387387 // ----- rollout sample policy ----- //
@@ -405,7 +405,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
405405
406406// returns the **nominal** trajectory (this is the purple trace)
407407const Trajectory* CrossEntropyPlanner::BestTrajectory () {
408- return &trajectory[0 ];
408+ return &trajectory[elite_average_index_ ];
409409}
410410
411411// visualize planner-specific traces
0 commit comments