Skip to content

Commit 3eb32f9

Browse files
committed
compute gradient candidates after policy update
1 parent 1737fcd commit 3eb32f9

File tree

2 files changed

+54
-32
lines changed

2 files changed

+54
-32
lines changed

mjpc/planners/sample_gradient/planner.cc

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,32 @@ void SampleGradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
173173
num_gradient_ = std::min(num_gradient_, num_trajectory - 1);
174174
int num_gradient = num_gradient_;
175175

176+
// number of noisy policies
177+
int num_noisy = num_trajectory - num_gradient;
178+
176179
// resize number of mjData
177180
ResizeMjData(model, pool.NumThreads());
178181

179182
// copy nominal policy
180-
policy.num_parameters = model->nu * policy.num_spline_points;
183+
int num_spline_points = policy.num_spline_points;
184+
PolicyRepresentation representation = policy.representation;
185+
policy.num_parameters = model->nu * num_spline_points;
181186
{
182187
const std::shared_lock<std::shared_mutex> lock(mtx_);
183-
resampled_policy.CopyFrom(policy, policy.num_spline_points);
188+
resampled_policy.CopyFrom(policy, num_spline_points);
184189
}
185190

186191
// resample nominal policy to current time
187-
this->ResamplePolicy(horizon);
192+
this->ResamplePolicy(resampled_policy, horizon, num_spline_points,
193+
representation);
194+
195+
// resample gradient policies to current time
196+
// TODO(taylor): a bit faster to do in Rollouts, but needs more scratch to be
197+
// memory safe
198+
for (int i = 0; i < num_gradient; i++) {
199+
this->ResamplePolicy(candidate_policy[num_noisy + i], horizon,
200+
num_spline_points, representation);
201+
}
188202

189203
// ----- roll out noisy policies ----- //
190204
// start timer
@@ -196,15 +210,6 @@ void SampleGradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
196210
// stop timer
197211
rollouts_compute_time = GetDuration(perturb_rollouts_start);
198212

199-
// start timer
200-
auto gradient_rollouts_start = std::chrono::steady_clock::now();
201-
202-
// roll out interpolated policies between Cauchy and Newton points
203-
this->GradientCandidates(num_trajectory, num_gradient, horizon, pool);
204-
205-
// stop timer
206-
gradient_candidates_compute_time = GetDuration(gradient_rollouts_start);
207-
208213
// ----- update policy ----- //
209214
// start timer
210215
auto policy_update_start = std::chrono::steady_clock::now();
@@ -256,6 +261,16 @@ void SampleGradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
256261

257262
// stop timer
258263
policy_update_compute_time = GetDuration(policy_update_start);
264+
265+
// ----- compute gradient candidate policies ----- //
266+
// start timer
267+
auto gradient_start = std::chrono::steady_clock::now();
268+
269+
// candidate policies
270+
this->GradientCandidates(num_trajectory, num_gradient, horizon, pool);
271+
272+
// stop timer
273+
gradient_candidates_compute_time = GetDuration(gradient_start);
259274
}
260275

261276
// compute trajectory using nominal policy
@@ -285,10 +300,11 @@ void SampleGradientPlanner::ActionFromPolicy(double* action,
285300
}
286301

287302
// update policy via resampling
288-
void SampleGradientPlanner::ResamplePolicy(int horizon) {
289-
// dimensions
290-
int num_parameters = resampled_policy.num_parameters;
291-
int num_spline_points = resampled_policy.num_spline_points;
303+
void SampleGradientPlanner::ResamplePolicy(
304+
SamplingPolicy& policy, int horizon, int num_spline_points,
305+
PolicyRepresentation representation) {
306+
// dimension
307+
int num_parameters = model->nu * num_spline_points;
292308

293309
// time
294310
double nominal_time = time;
@@ -298,23 +314,27 @@ void SampleGradientPlanner::ResamplePolicy(int horizon) {
298314
// get spline points
299315
for (int t = 0; t < num_spline_points; t++) {
300316
times_scratch[t] = nominal_time;
301-
resampled_policy.Action(DataAt(parameters_scratch, t * model->nu), nullptr,
302-
nominal_time);
317+
policy.Action(DataAt(parameters_scratch, t * model->nu), nullptr,
318+
nominal_time);
303319
nominal_time += time_shift;
304320
}
305321

306322
// copy resampled policy parameters
307-
mju_copy(resampled_policy.parameters.data(), parameters_scratch.data(),
323+
mju_copy(policy.parameters.data(), parameters_scratch.data(),
308324
num_parameters);
309-
mju_copy(resampled_policy.times.data(), times_scratch.data(),
325+
mju_copy(policy.times.data(), times_scratch.data(),
310326
num_spline_points);
311327

312328
// time step linear range
313-
LinearRange(resampled_policy.times.data(), time_shift,
314-
resampled_policy.times[0], num_spline_points);
329+
LinearRange(policy.times.data(), time_shift,
330+
policy.times[0], num_spline_points);
331+
332+
// set dimensions
333+
policy.num_parameters = num_parameters;
334+
policy.num_spline_points = num_spline_points;
315335

316336
// representation
317-
resampled_policy.representation = policy.representation;
337+
policy.representation = representation;
318338
}
319339

320340
// add random noise to nominal policy
@@ -365,7 +385,7 @@ void SampleGradientPlanner::Rollouts(int num_trajectory, int num_gradient,
365385
&mocap = this->mocap, &userdata = this->userdata, horizon,
366386
idx_nominal = this->idx_nominal, num_trajectory,
367387
num_gradient, i]() {
368-
// nominal + noisy policies
388+
// nominal and noisy policies
369389
if (i < num_trajectory - num_gradient) {
370390
// copy nominal policy
371391
s.candidate_policy[i].CopyFrom(s.resampled_policy,
@@ -404,6 +424,7 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
404424

405425
// number of parameters
406426
int num_parameters = resampled_policy.num_parameters;
427+
int num_spline_points = resampled_policy.num_spline_points;
407428

408429
// cache old gradient
409430
mju_copy(gradient_previous.data(), gradient.data(), num_parameters);
@@ -428,7 +449,7 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
428449

429450
// normalize gradient
430451
// TODO(taylor): should we normalize?
431-
// mju_normalize(gradient.data(), num_parameters);
452+
mju_normalize(gradient.data(), num_parameters);
432453

433454
// compute step sizes along gradient
434455
std::vector<double> step_size(num_gradient);
@@ -442,8 +463,7 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
442463
// these candidates will be evaluated at the next planning iteration
443464
for (int i = num_noisy; i < num_trajectory; i++) {
444465
// copy nominal policy
445-
candidate_policy[i].CopyFrom(resampled_policy,
446-
resampled_policy.num_spline_points);
466+
candidate_policy[i].CopyFrom(resampled_policy, num_spline_points);
447467
candidate_policy[i].representation = resampled_policy.representation;
448468

449469
// scaling
@@ -453,15 +473,15 @@ void SampleGradientPlanner::GradientCandidates(int num_trajectory,
453473

454474
// gradient step
455475
mju_addToScl(candidate_policy[i].parameters.data(), gradient.data(),
456-
-scaling * gradient_filter, resampled_policy.num_parameters);
476+
-scaling * gradient_filter, num_parameters);
457477

458478
// TODO(taylor): resample the gradient_previous?
459479
mju_addToScl(candidate_policy[i].parameters.data(),
460480
gradient_previous.data(), -scaling * (1.0 - gradient_filter),
461-
resampled_policy.num_parameters);
481+
num_parameters);
462482

463483
// clamp parameters
464-
for (int t = 0; t < resampled_policy.num_spline_points; t++) {
484+
for (int t = 0; t < num_spline_points; t++) {
465485
Clamp(DataAt(candidate_policy[i].parameters, t * model->nu),
466486
model->actuator_ctrlrange, model->nu);
467487
}

mjpc/planners/sample_gradient/planner.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class SampleGradientPlanner : public Planner {
6363
bool use_previous = false) override;
6464

6565
// resample nominal policy
66-
void ResamplePolicy(int horizon);
66+
void ResamplePolicy(SamplingPolicy& policy, int horizon,
67+
int num_spline_points,
68+
PolicyRepresentation representation);
6769

6870
// add noise to nominal policy
6971
void AddNoiseToPolicy(int i);
@@ -143,7 +145,7 @@ class SampleGradientPlanner : public Planner {
143145
double policy_update_compute_time;
144146

145147
int num_trajectory_;
146-
int num_gradient_; // number of gradient candidates
148+
int num_gradient_; // number of gradient candidates
147149
mutable std::shared_mutex mtx_;
148150

149151
// approximate gradient

0 commit comments

Comments (0)