@@ -173,18 +173,34 @@ void SampleGradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
   num_gradient_ = std::min(num_gradient_, num_trajectory - 1);
   int num_gradient = num_gradient_;
 
+  // number of noisy policies
+  int num_noisy = num_trajectory - num_gradient;
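+  // candidate_policy layout: [0, num_noisy) nominal + noisy samples,
+  // [num_noisy, num_trajectory) gradient candidates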
+
   // resize number of mjData
   ResizeMjData(model, pool.NumThreads());
 
   // copy nominal policy
-  policy.num_parameters = model->nu * policy.num_spline_points;
+  int num_spline_points = policy.num_spline_points;
+  PolicyRepresentation representation = policy.representation;
+  policy.num_parameters = model->nu * num_spline_points;
   {
     const std::shared_lock<std::shared_mutex> lock(mtx_);
-    resampled_policy.CopyFrom(policy, policy.num_spline_points);
+    resampled_policy.CopyFrom(policy, num_spline_points);
   }
 
   // resample nominal policy to current time
-  this->ResamplePolicy(horizon);
+  this->ResamplePolicy(resampled_policy, horizon, num_spline_points,
+                       representation);
+
+  // resample gradient policies to current time
+  // TODO(taylor): a bit faster to do in Rollouts, but needs more scratch to be
+  // memory safe
+  for (int i = 0; i < num_gradient; i++) {
+    this->ResamplePolicy(candidate_policy[num_noisy + i], horizon,
+                         num_spline_points, representation);
+  }
 
   // ----- roll out noisy policies ----- //
   // start timer
@@ -196,15 +212,6 @@ void SampleGradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
   // stop timer
   rollouts_compute_time = GetDuration(perturb_rollouts_start);
 
-  // start timer
-  auto gradient_rollouts_start = std::chrono::steady_clock::now();
-
-  // roll out interpolated policies between Cauchy and Newton points
-  this->GradientCandidates(num_trajectory, num_gradient, horizon, pool);
-
-  // stop timer
-  gradient_candidates_compute_time = GetDuration(gradient_rollouts_start);
-
   // ----- update policy ----- //
   // start timer
   auto policy_update_start = std::chrono::steady_clock::now();
@@ -256,6 +263,17 @@ void SampleGradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
 
   // stop timer
   policy_update_compute_time = GetDuration(policy_update_start);
+
+  // ----- compute gradient candidate policies ----- //
+  // start timer
+  auto gradient_start = std::chrono::steady_clock::now();
+
+  // candidate policies
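+  // these are rolled out at the next planning iteration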
+  this->GradientCandidates(num_trajectory, num_gradient, horizon, pool);
+
+  // stop timer
+  gradient_candidates_compute_time = GetDuration(gradient_start);
 }
 
 // compute trajectory using nominal policy
@@ -285,10 +303,11 @@ void SampleGradientPlanner::ActionFromPolicy(double* action,
 }
 
 // update policy via resampling
-void SampleGradientPlanner::ResamplePolicy(int horizon) {
-  // dimensions
-  int num_parameters = resampled_policy.num_parameters;
-  int num_spline_points = resampled_policy.num_spline_points;
+void SampleGradientPlanner::ResamplePolicy(
+    SamplingPolicy& policy, int horizon, int num_spline_points,
+    PolicyRepresentation representation) {
+  // dimension
+  int num_parameters = model->nu * num_spline_points;
 
   // time
   double nominal_time = time;
@@ -298,23 +317,27 @@ void SampleGradientPlanner::ResamplePolicy(int horizon) {
   // get spline points
   for (int t = 0; t < num_spline_points; t++) {
     times_scratch[t] = nominal_time;
-    resampled_policy.Action(DataAt(parameters_scratch, t * model->nu), nullptr,
-                            nominal_time);
+    policy.Action(DataAt(parameters_scratch, t * model->nu), nullptr,
+                  nominal_time);
     nominal_time += time_shift;
   }
 
   // copy resampled policy parameters
-  mju_copy(resampled_policy.parameters.data(), parameters_scratch.data(),
+  mju_copy(policy.parameters.data(), parameters_scratch.data(),
            num_parameters);
-  mju_copy(resampled_policy.times.data(), times_scratch.data(),
+  mju_copy(policy.times.data(), times_scratch.data(),
            num_spline_points);
 
   // time step linear range
-  LinearRange(resampled_policy.times.data(), time_shift,
-              resampled_policy.times[0], num_spline_points);
+  LinearRange(policy.times.data(), time_shift,
+              policy.times[0], num_spline_points);
+
+  // set dimensions
+  policy.num_parameters = num_parameters;
+  policy.num_spline_points = num_spline_points;
 
   // representation
-  resampled_policy.representation = policy.representation;
+  policy.representation = representation;
 }
 
 // add random noise to nominal policy
@@ -365,7 +388,7 @@ void SampleGradientPlanner::Rollouts(int num_trajectory, int num_gradient,
                    &mocap = this->mocap, &userdata = this->userdata, horizon,
                    idx_nominal = this->idx_nominal, num_trajectory,
                    num_gradient, i]() {
-      // nominal + noisy policies
+      // nominal and noisy policies
       if (i < num_trajectory - num_gradient) {
         // copy nominal policy
         s.candidate_policy[i].CopyFrom(s.resampled_policy,
@@ -404,6 +427,7 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
 
   // number of parameters
   int num_parameters = resampled_policy.num_parameters;
+  int num_spline_points = resampled_policy.num_spline_points;
 
   // cache old gradient
   mju_copy(gradient_previous.data(), gradient.data(), num_parameters);
@@ -428,7 +452,8 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
 
   // normalize gradient
   // TODO(taylor): should we normalize?
-  // mju_normalize(gradient.data(), num_parameters);
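+  // gradient has unit norm after this call; step sizes set the step length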
+  mju_normalize(gradient.data(), num_parameters);
 
   // compute step sizes along gradient
   std::vector<double> step_size(num_gradient);
@@ -442,8 +467,7 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
   // these candidates will be evaluated at the next planning iteration
   for (int i = num_noisy; i < num_trajectory; i++) {
     // copy nominal policy
-    candidate_policy[i].CopyFrom(resampled_policy,
-                                 resampled_policy.num_spline_points);
+    candidate_policy[i].CopyFrom(resampled_policy, num_spline_points);
     candidate_policy[i].representation = resampled_policy.representation;
 
     // scaling
@@ -453,15 +477,16 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
 
     // gradient step
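+    // step along a gradient_filter-weighted blend of current and previous gradients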
     mju_addToScl(candidate_policy[i].parameters.data(), gradient.data(),
-                 -scaling * gradient_filter, resampled_policy.num_parameters);
+                 -scaling * gradient_filter, num_parameters);
 
     // TODO(taylor): resample the gradient_previous?
     mju_addToScl(candidate_policy[i].parameters.data(),
                  gradient_previous.data(), -scaling * (1.0 - gradient_filter),
-                 resampled_policy.num_parameters);
+                 num_parameters);
 
     // clamp parameters
-    for (int t = 0; t < resampled_policy.num_spline_points; t++) {
+    for (int t = 0; t < num_spline_points; t++) {
       Clamp(DataAt(candidate_policy[i].parameters, t * model->nu),
             model->actuator_ctrlrange, model->nu);
     }