Skip to content

Commit 8764eaa

Browse files
authored
Merge pull request #259 from thowell/num_parameters
Planner::NumParameters
2 parents 744912d + 958db40 commit 8764eaa

File tree

9 files changed

+37
-6
lines changed

9 files changed

+37
-6
lines changed

mjpc/planners/cross_entropy/planner.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ class CrossEntropyPlanner : public Planner {
8484
void Plots(mjvFigure* fig_planner, mjvFigure* fig_timer, int planner_shift,
8585
int timer_shift, int planning, int* shift) override;
8686

87+
// return number of parameters optimized by planner
88+
int NumParameters() override {
89+
return policy.num_spline_points * policy.model->nu;
90+
};
91+
8792
// ----- members ----- //
8893
mjModel* model;
8994
const Task* task;

mjpc/planners/gradient/planner.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ class GradientPlanner : public Planner {
8888
void Plots(mjvFigure* fig_planner, mjvFigure* fig_timer, int planner_shift,
8989
int timer_shift, int planning, int* shift) override;
9090

91+
// return number of parameters optimized by planner
92+
int NumParameters() override {
93+
return policy.num_spline_points * policy.model->nu;
94+
};
95+
9196
// ----- members ----- //
9297
mjModel* model;
9398
const Task* task;

mjpc/planners/ilqg/planner.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ void iLQGPlanner::NominalTrajectory(int horizon, ThreadPool& pool) {
165165
if (num_trajectory_ == 0) {
166166
return;
167167
}
168+
// set policy trajectory horizon
169+
policy.trajectory.horizon = horizon;
170+
168171
// resize data for rollouts
169172
ResizeMjData(model, pool.NumThreads());
170173

mjpc/planners/ilqg/planner.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ class iLQGPlanner : public Planner {
7171
void Plots(mjvFigure* fig_planner, mjvFigure* fig_timer, int planner_shift,
7272
int timer_shift, int planning, int* shift) override;
7373

74+
// return number of parameters optimized by planner
75+
int NumParameters() override {
76+
return policy.trajectory.dim_action * (policy.trajectory.horizon - 1);
77+
};
78+
7479
// single iLQG iteration
7580
void Iteration(int horizon, ThreadPool& pool);
7681

@@ -85,8 +90,6 @@ class iLQGPlanner : public Planner {
8590

8691
void UpdateNumTrajectoriesFromGUI();
8792

88-
//
89-
9093
// ----- members ----- //
9194
mjModel* model;
9295
const Task* task;

mjpc/planners/ilqs/planner.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ class iLQSPlanner : public Planner {
8181
void Plots(mjvFigure* fig_planner, mjvFigure* fig_timer, int planner_shift,
8282
int timer_shift, int planning, int* shift) override;
8383

84+
// return number of parameters optimized by planner
85+
int NumParameters() override {
86+
return sampling.NumParameters() + ilqg.NumParameters();
87+
};
88+
8489
// ----- planners ----- //
8590
SamplingPlanner sampling;
8691
iLQGPlanner ilqg;

mjpc/planners/planner.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ class Planner {
7171
int planner_shift, int timer_shift, int planning,
7272
int* shift) = 0;
7373

74+
// return number of parameters optimized by planner
75+
virtual int NumParameters() = 0;
76+
7477
std::vector<UniqueMjData> data_;
7578
void ResizeMjData(const mjModel* model, int num_threads);
7679
};

mjpc/planners/robust/robust_planner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class RobustPlanner : public Planner {
5454
void GUI(mjUI& ui) override;
5555
void Plots(mjvFigure* fig_planner, mjvFigure* fig_timer, int planner_shift,
5656
int timer_shift, int planning, int* shift) override;
57+
int NumParameters() override { return delegate_->NumParameters(); };
5758

5859
private:
5960
void ResizeTrajectories(int ntrajectories);

mjpc/planners/sampling/planner.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ class SamplingPlanner : public RankedPlanner {
9191
void Plots(mjvFigure* fig_planner, mjvFigure* fig_timer, int planner_shift,
9292
int timer_shift, int planning, int* shift) override;
9393

94+
// return number of parameters optimized by planner
95+
int NumParameters() override {
96+
return policy.num_spline_points * policy.model->nu;
97+
};
98+
9499
// optimizes policies, but rather than picking the best, generate up to
95100
// ncandidates. returns number of candidates created.
96101
int OptimizePolicyCandidates(int ncandidates, int horizon,

mjpc/simulate.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,13 +526,14 @@ void UpdateInfoText(mj::Simulate* sim,
526526
solerr = mju_log10(mju_max(mjMINVAL, solerr));
527527

528528
// prepare info text
529-
mju::strcpy_arr(title, "Objective\nDoFs\nControls\nTime\nMemory");
529+
mju::strcpy_arr(title, "Objective\nDoFs\nControls\nParameters\nTime\nMemory");
530530
const mjpc::Trajectory* best_trajectory =
531531
sim->agent->ActivePlanner().BestTrajectory();
532532
if (best_trajectory) {
533-
mju::sprintf_arr(content, "%.3f\n%d\n%d\n%-9.3f\n%.2g of %s",
534-
best_trajectory->total_return, m->nv, m->nu, d->time,
535-
d->maxuse_arena / (double)(d->narena),
533+
int nparam = sim->agent->ActivePlanner().NumParameters();
534+
mju::sprintf_arr(content, "%.3f\n%d\n%d\n%d\n%-9.3f\n%.2g of %s",
535+
best_trajectory->total_return, m->nv, m->nu, nparam,
536+
d->time, d->maxuse_arena / (double)(d->narena),
536537
mju_writeNumBytes(d->narena));
537538
}
538539

0 commit comments

Comments
 (0)