Skip to content

Commit d3dc2e2

Browse files
committed
Rearrange setters and getters for reference state
1 parent 8925fb9 commit d3dc2e2

File tree

13 files changed

+85
-40
lines changed

13 files changed

+85
-40
lines changed

bindings/expose-mpc.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,14 @@ namespace simple_mpc
7979
.def("getReferencePose", &MPC::getReferencePose, bp::args("self", "t", "ee_name"))
8080
.def("setTerminalReferencePose", &MPC::setTerminalReferencePose, bp::args("self", "ee_name", "pose_ref"))
8181
.def_readwrite("velocity_base", &MPC::velocity_base_)
82-
.def_readwrite("pose_base", &MPC::pose_base_)
82+
.def_readwrite("x_reference", &MPC::x_reference_)
8383
.def_readonly("ocp_handler", &MPC::ocp_handler_)
84-
.def("setPoseBase", &MPC::setPoseBase, ("self"_a, "pose_base"))
85-
.def("getPoseBase", &MPC::getPoseBase, ("self"_a, "t"))
8684
.def("switchToWalk", &MPC::switchToWalk, ("self"_a, "velocity_base"))
8785
.def("switchToStand", &MPC::switchToStand, "self"_a)
8886
.def("getFootTakeoffCycle", &MPC::getFootTakeoffCycle, ("self"_a, "ee_name"))
8987
.def("getFootLandCycle", &MPC::getFootLandCycle, ("self"_a, "ee_name"))
9088
.def("getStateDerivative", &MPC::getStateDerivative, ("self"_a, "t"))
9189
.def("getContactForces", &MPC::getContactForces, ("self"_a, "t"))
92-
.def("setReferenceState", &MPC::setReferenceState, ("self"_a, "state_ref"))
9390
.def("getCyclingContactState", &MPC::getCyclingContactState, ("self"_a, "t", "ee_name"))
9491
.def(
9592
"getModelHandler", &MPC::getModelHandler, "self"_a, bp::return_internal_reference<>(),

bindings/expose-problem.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ namespace simple_mpc
6969
.def("getVelocityBase", bp::pure_virtual(&OCPHandler::getVelocityBase), bp::args("self", "t"))
7070
.def("setPoseBase", bp::pure_virtual(&OCPHandler::setPoseBase), bp::args("self", "t", "pose_base"))
7171
.def("getPoseBase", bp::pure_virtual(&OCPHandler::getPoseBase), bp::args("self", "t"))
72+
.def("setReferenceState", bp::pure_virtual(&OCPHandler::setReferenceState), bp::args("self", "t", "x_ref"))
73+
.def("getReferenceState", bp::pure_virtual(&OCPHandler::getReferenceState), bp::args("self", "t"))
7274
.def("getProblemState", bp::pure_virtual(&OCPHandler::getProblemState), bp::args("self", "data_handler"))
7375
.def("getContactSupport", bp::pure_virtual(&OCPHandler::getContactSupport), bp::args("self", "t"))
7476
.def("getContactState", bp::pure_virtual(&OCPHandler::getContactState), bp::args("self", "t"))
@@ -77,8 +79,6 @@ namespace simple_mpc
7779
("self"_a, "x0", "horizon", "force_size", "gravity", "terminal_constraint"))
7880
.def("setReferenceControl", &OCPHandler::setReferenceControl, ("self"_a, "t", "u_ref"))
7981
.def("getReferenceControl", &OCPHandler::getReferenceControl, ("self"_a, "t"))
80-
.def("setReferenceState", &OCPHandler::setReferenceState, ("self"_a, "x_ref"))
81-
.def("getReferenceState", &OCPHandler::getReferenceState, ("self"_a))
8282
.def("getProblem", +[](OCPHandler & ocp) { return boost::ref(ocp.getProblem()); }, "self"_a);
8383

8484
exposeContainers();

include/simple-mpc/centroidal-dynamics.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ namespace simple_mpc
9797
const Eigen::VectorXd getProblemState(const RobotDataHandler & data_handler) override;
9898
size_t getContactSupport(const std::size_t t) override;
9999
std::vector<bool> getContactState(const std::size_t t) override;
100+
void setReferenceState(const std::size_t t, const ConstVectorRef & x_ref) override;
101+
const ConstVectorRef getReferenceState(const std::size_t t) override;
100102

101103
CentroidalSettings getSettings()
102104
{

include/simple-mpc/fulldynamics.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ namespace simple_mpc
102102
const Eigen::VectorXd getProblemState(const RobotDataHandler & data_handler) override;
103103
size_t getContactSupport(const std::size_t t) override;
104104
std::vector<bool> getContactState(const std::size_t t) override;
105+
void setReferenceState(const std::size_t t, const ConstVectorRef & x_ref) override;
106+
const ConstVectorRef getReferenceState(const std::size_t t) override;
105107
FullDynamicsSettings getSettings()
106108
{
107109
return settings_;

include/simple-mpc/kinodynamics.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ namespace simple_mpc
9090
const Eigen::VectorXd getProblemState(const RobotDataHandler & data_handler) override;
9191
size_t getContactSupport(const std::size_t t) override;
9292
std::vector<bool> getContactState(const std::size_t t) override;
93+
void setReferenceState(const std::size_t t, const ConstVectorRef & x_ref) override;
94+
const ConstVectorRef getReferenceState(const std::size_t t) override;
9395

9496
void computeControlFromForces(const std::map<std::string, Eigen::VectorXd> & force_refs);
9597

include/simple-mpc/mpc.hpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ namespace simple_mpc
8888
public:
8989
std::unique_ptr<SolverProxDDP> solver_;
9090
Vector6d velocity_base_;
91-
Vector7d pose_base_;
9291
Eigen::Vector3d next_pose_;
9392
Eigen::Vector2d twist_vect_;
9493
MPCSettings settings_;
@@ -120,24 +119,11 @@ namespace simple_mpc
120119
velocity_base_ = v;
121120
}
122121

123-
void setPoseBaseFromSE3(const pin::SE3 & pose_ref)
124-
{
125-
Eigen::Map<pin::SE3::Quaternion> q{pose_base_.tail<4>().data()};
126-
pose_base_.head<3>() = pose_ref.translation();
127-
q = pose_ref.rotation();
128-
}
129-
SIMPLE_MPC_DEPRECATED void setPoseBase(const Vector7d & pose_ref)
130-
{
131-
pose_base_ = pose_ref;
132-
}
133-
134122
void setReferenceState(const VectorXd & state_ref)
135123
{
136-
ocp_handler_->setReferenceState(state_ref);
124+
x_reference_ = state_ref;
137125
}
138126

139-
ConstVectorRef getPoseBase(const std::size_t t) const;
140-
141127
// getters and setters
142128
TrajOptProblem & getTrajOptProblem();
143129

@@ -203,6 +189,7 @@ namespace simple_mpc
203189
std::vector<VectorXd> us_;
204190
// Riccati gains
205191
std::vector<MatrixXd> Ks_;
192+
VectorXd x_reference_;
206193

207194
// Initial quantities
208195
VectorXd x0_;

include/simple-mpc/ocp-handler.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ namespace simple_mpc
9696
virtual const Eigen::VectorXd getProblemState(const RobotDataHandler & data_handler) = 0;
9797
virtual size_t getContactSupport(const std::size_t t) = 0;
9898
virtual std::vector<bool> getContactState(const std::size_t t) = 0;
99+
virtual void setReferenceState(const std::size_t t, const ConstVectorRef & x_ref) = 0;
100+
virtual const ConstVectorRef getReferenceState(const std::size_t t) = 0;
99101

100102
/// Common functions for all problems
101103

@@ -111,9 +113,6 @@ namespace simple_mpc
111113
void setReferenceControl(const std::size_t t, const ConstVectorRef & u_ref);
112114
ConstVectorRef getReferenceControl(const std::size_t t);
113115

114-
void setReferenceState(const ConstVectorRef & x_ref);
115-
ConstVectorRef getReferenceState();
116-
117116
// Getter for various objects and quantities
118117
CostStack * getCostStack(std::size_t t);
119118
CostStack * getTerminalCostStack();

include/simple-mpc/python/py-ocp-handler.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ namespace simple_mpc
170170
SIMPLE_MPC_PYTHON_OVERRIDE_PURE(std::vector<bool>, "getContactState", t);
171171
}
172172

173+
void setReferenceState(const std::size_t t, const ConstVectorRef & x_ref) override
174+
{
175+
SIMPLE_MPC_PYTHON_OVERRIDE_PURE(void, "setReferenceState", t, x_ref);
176+
}
177+
178+
const ConstVectorRef getReferenceState(const std::size_t t) override
179+
{
180+
SIMPLE_MPC_PYTHON_OVERRIDE_PURE(ConstVectorRef, "getReferenceState", t);
181+
}
182+
173183
void setReferenceControl(const std::size_t t, const ConstVectorRef & u_ref)
174184
{
175185
SIMPLE_MPC_PYTHON_OVERRIDE(void, OCPHandler, setReferenceControl, t, u_ref);

src/centroidal-dynamics.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace simple_mpc
3333
control_ref_.resize(nu_);
3434
control_ref_.setZero();
3535
com_ref_.setZero();
36+
x0_.resize(9);
3637
}
3738

3839
StageModel CentroidalOCP::createStage(
@@ -225,6 +226,7 @@ namespace simple_mpc
225226

226227
void CentroidalOCP::setVelocityBase(const std::size_t t, const ConstVectorRef & velocity_base)
227228
{
229+
assert(velocity_base.size() == 6 && "velocity_base not of the right size");
228230
CostStack * cs = getCostStack(t);
229231
QuadraticResidualCost * qcm = cs->getComponent<QuadraticResidualCost>("linear_mom_cost");
230232
LinearMomentumResidual * cfm = qcm->getResidual<LinearMomentumResidual>();
@@ -246,10 +248,7 @@ namespace simple_mpc
246248

247249
void CentroidalOCP::setPoseBase(const std::size_t t, const ConstVectorRef & pose_base)
248250
{
249-
if (pose_base.size() != 7)
250-
{
251-
throw std::runtime_error("pose_base size should be 7");
252-
}
251+
assert(pose_base.size() == 3 && "pose_base not of the right size");
253252
CostStack * cs = getCostStack(t);
254253
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>("com_cost");
255254
CentroidalCoMResidual * cfr = qrc->getResidual<CentroidalCoMResidual>();
@@ -290,6 +289,20 @@ namespace simple_mpc
290289
return contact_state;
291290
}
292291

292+
void CentroidalOCP::setReferenceState(const std::size_t t, const ConstVectorRef & x_ref)
293+
{
294+
assert(x_ref.size() == 9 && "x_ref not of the right size");
295+
setPoseBase(t, x_ref.head(3));
296+
setVelocityBase(t, x_ref.tail(6));
297+
}
298+
299+
const ConstVectorRef CentroidalOCP::getReferenceState(const std::size_t t)
300+
{
301+
x0_.head(3) = getPoseBase(t);
302+
x0_.tail(6) = getVelocityBase(t);
303+
return x0_;
304+
}
305+
293306
CostStack CentroidalOCP::createTerminalCost()
294307
{
295308
auto ter_space = VectorSpace(nx_);

src/fulldynamics.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ namespace simple_mpc
8888
auto space = MultibodyPhaseSpace(model_handler_.getModel());
8989
auto rcost = CostStack(space, nu_);
9090

91-
rcost.addCost("state_cost", QuadraticStateCost(space, nu_, x0_, settings_.w_x));
91+
rcost.addCost("state_cost", QuadraticStateCost(space, nu_, model_handler_.getReferenceState(), settings_.w_x));
9292
rcost.addCost("control_cost", QuadraticControlCost(space, Eigen::VectorXd::Zero(nu_), settings_.w_u));
9393

9494
auto cent_mom = CentroidalMomentumResidual(space.ndx(), nu_, model_handler_.getModel(), Eigen::VectorXd::Zero(6));
@@ -337,6 +337,7 @@ namespace simple_mpc
337337
}
338338
CostStack * cs = getCostStack(t);
339339
QuadraticStateCost * qc = cs->getComponent<QuadraticStateCost>("state_cost");
340+
x0_ = getReferenceState(t);
340341
x0_.segment(nq_, 6) = velocity_base;
341342
qc->setTarget(x0_);
342343
}
@@ -356,6 +357,7 @@ namespace simple_mpc
356357
}
357358
CostStack * cs = getCostStack(t);
358359
QuadraticStateCost * qc = cs->getComponent<QuadraticStateCost>("state_cost");
360+
x0_ = getReferenceState(t);
359361
x0_.head(7) = pose_base;
360362
qc->setTarget(x0_);
361363
}
@@ -398,14 +400,30 @@ namespace simple_mpc
398400
return contact_state;
399401
}
400402

403+
void FullDynamicsOCP::setReferenceState(const std::size_t t, const ConstVectorRef & x_ref)
404+
{
405+
assert(x_ref.size() == nq_ + nv_ && "x_ref not of the right size");
406+
CostStack * cs = getCostStack(t);
407+
QuadraticStateCost * qc = cs->getComponent<QuadraticStateCost>("state_cost");
408+
qc->setTarget(x_ref);
409+
}
410+
411+
const ConstVectorRef FullDynamicsOCP::getReferenceState(const std::size_t t)
412+
{
413+
CostStack * cs = getCostStack(t);
414+
QuadraticStateCost * qc = cs->getComponent<QuadraticStateCost>("state_cost");
415+
return qc->getTarget();
416+
}
417+
401418
CostStack FullDynamicsOCP::createTerminalCost()
402419
{
403420
auto ter_space = MultibodyPhaseSpace(model_handler_.getModel());
404421
auto term_cost = CostStack(ter_space, nu_);
405422
auto cent_mom =
406423
CentroidalMomentumResidual(ter_space.ndx(), nu_, model_handler_.getModel(), Eigen::VectorXd::Zero(6));
407424

408-
term_cost.addCost("state_cost", QuadraticStateCost(ter_space, nu_, x0_, settings_.w_x));
425+
term_cost.addCost(
426+
"state_cost", QuadraticStateCost(ter_space, nu_, model_handler_.getReferenceState(), settings_.w_x));
409427
/* term_cost.addCost(
410428
"centroidal_cost",
411429
QuadraticResidualCost(ter_space, cent_mom, settings_.w_cent)); */

0 commit comments

Comments
 (0)