Skip to content

Commit d035b87

Browse files
committed
Working arm mpc
1 parent f9eb135 commit d035b87

File tree

8 files changed

+120
-56
lines changed

8 files changed

+120
-56
lines changed

bindings/expose-arm-dynamics.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1+
///////////////////////////////////////////////////////////////////////////////
2+
// BSD 2-Clause License
3+
//
4+
// Copyright (C) 2025, INRIA
5+
// Copyright note valid unless otherwise stated in individual files.
6+
// All rights reserved.
7+
///////////////////////////////////////////////////////////////////////////////
8+
19
#include "simple-mpc/arm-dynamics.hpp"
210
#include "simple-mpc/python.hpp"
311

12+
#include "simple-mpc/fwd.hpp"
13+
#include <eigenpy/eigenpy.hpp>
414
#include <eigenpy/std-map.hpp>
515

616
namespace simple_mpc::python
@@ -59,7 +69,13 @@ namespace simple_mpc::python
5969
"__init__",
6070
bp::make_constructor(&createArmDynamics, bp::default_call_policies(), ("settings"_a, "model_handler")))
6171
.def("getSettings", &getSettingsArm)
62-
.def("createStage", &ArmDynamicsOCP::createStage, bp::args("self", "reaching", "reach_pose"));
72+
.def("createStage", &ArmDynamicsOCP::createStage, bp::args("self", "reaching", "reach_pose"))
73+
.def("createProblem", &ArmDynamicsOCP::createProblem, ("self"_a, "x0", "horizon"))
74+
.def("setReferencePose", &ArmDynamicsOCP::setReferencePose, bp::args("self", "t", "pose_ref"))
75+
.def("getReferencePose", &ArmDynamicsOCP::getReferencePose, bp::args("self", "t"))
76+
.def("setReferenceState", &ArmDynamicsOCP::setReferenceState, bp::args("self", "t", "x_ref"))
77+
.def("getReferenceState", &ArmDynamicsOCP::getReferenceState, bp::args("self", "t"))
78+
.def("getProblem", +[](ArmDynamicsOCP & ocp) { return boost::ref(ocp.getProblem()); }, "self"_a);
6379
}
6480

6581
} // namespace simple_mpc::python

bindings/expose-arm-mpc.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
///////////////////////////////////////////////////////////////////////////////
22
// BSD 2-Clause License
33
//
4-
// Copyright (C) 2024, INRIA
4+
// Copyright (C) 2025, INRIA
55
// Copyright note valid unless otherwise stated in individual files.
66
// All rights reserved.
77
///////////////////////////////////////////////////////////////////////////////
@@ -60,7 +60,6 @@ namespace simple_mpc
6060
bp::class_<ArmMPC, boost::noncopyable>("ArmMPC", bp::no_init)
6161
.def("__init__", bp::make_constructor(&createArmMPC, bp::default_call_policies()))
6262
.def("getSettings", &getSettings)
63-
.def("generateReachHorizon", &ArmMPC::generateReachHorizon, bp::args("self", "reach_pose"))
6463
.def("iterate", &ArmMPC::iterate, bp::args("self", "x"))
6564
.def("setReferencePose", &ArmMPC::setReferencePose, bp::args("self", "t", "pose_ref"))
6665
.def("getReferencePose", &ArmMPC::getReferencePose, bp::args("self", "t"))
@@ -78,9 +77,6 @@ namespace simple_mpc
7877
.def(
7978
"getTrajOptProblem", &ArmMPC::getTrajOptProblem, "self"_a, bp::return_internal_reference<>(),
8079
"Get the trajectory optimal problem.")
81-
.def(
82-
"getReachHorizon", &ArmMPC::getReachHorizon, "self"_a, bp::return_internal_reference<>(),
83-
"Get the reach horizon.")
8480
.add_property("solver", bp::make_getter(&ArmMPC::solver_, eigenpy::ReturnInternalStdUniquePtr{}))
8581
.add_property("xs", &ArmMPC::xs_)
8682
.add_property("us", &ArmMPC::us_)

include/simple-mpc/arm-dynamics.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,21 @@ namespace simple_mpc
7676

7777
// Getters and setters
7878
CostStack * getCostStack(std::size_t t);
79+
CostStack * getTerminalCostStack();
7980
void deactivateReach(const std::size_t t);
8081
void activateReach(const std::size_t t);
8182
void setReferencePose(const std::size_t t, const Eigen::Vector3d & pose_ref);
8283
const Eigen::Vector3d getReferencePose(const std::size_t t);
84+
void setTerminalReferencePose(const Eigen::Vector3d & pose_ref);
85+
const Eigen::Vector3d getTerminalReferencePose();
8386
const Eigen::VectorXd getProblemState(const RobotDataHandler & data_handler);
8487
void setReferenceState(const std::size_t t, const ConstVectorRef & x_ref);
8588
const ConstVectorRef getReferenceState(const std::size_t t);
89+
void setWeight(const std::size_t t, const std::string key, double weight);
90+
double getWeight(const std::size_t t, const std::string key);
91+
void setTerminalWeight(const std::string key, double weight);
92+
double getTerminalWeight(const std::string key);
93+
8694
ArmDynamicsSettings getSettings()
8795
{
8896
return settings_;

include/simple-mpc/arm-mpc.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ namespace simple_mpc
5252
REACHING
5353
};
5454

55-
std::vector<std::shared_ptr<StageModel>> reach_horizon_;
56-
std::vector<std::shared_ptr<StageData>> reach_horizon_data_;
57-
std::vector<std::shared_ptr<StageModel>> rest_horizon_;
58-
std::vector<std::shared_ptr<StageData>> rest_horizon_data_;
5955
// INTERNAL UPDATING function
6056
void updateTargetReference();
6157

@@ -106,11 +102,6 @@ namespace simple_mpc
106102
return ocp_handler_->getModelHandler();
107103
}
108104

109-
std::vector<std::shared_ptr<StageModel>> & getReachHorizon()
110-
{
111-
return reach_horizon_;
112-
}
113-
114105
const ConstVectorRef getStateDerivative(const std::size_t t);
115106

116107
void switchToReach(const Eigen::Vector3d & reach_pose);

include/simple-mpc/fwd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace simple_mpc
3535
class FullDynamicsOCP;
3636
class KinodynamicsOCP;
3737
class CentroidalOCP;
38+
class ArmDynamicsOCP;
3839
class OCPHandler;
3940
class IDSolver;
4041
class IKIDSolver;

src/arm-dynamics.cpp

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,18 @@ namespace simple_mpc
3636
auto space = MultibodyPhaseSpace(model_handler_.getModel());
3737
auto rcost = CostStack(space, nv_);
3838

39-
rcost.addCost("state_cost", QuadraticStateCost(space, nv_, model_handler_.getReferenceState(), settings_.w_x));
39+
rcost.addCost("state_cost", QuadraticStateCost(space, nv_, x0_, settings_.w_x));
4040
rcost.addCost("control_cost", QuadraticControlCost(space, Eigen::VectorXd::Zero(nv_), settings_.w_u));
4141

4242
FrameTranslationResidual frame_residual =
4343
FrameTranslationResidual(space.ndx(), nv_, model_handler_.getModel(), reach_pose, ee_id_);
4444
if (reaching)
4545
{
46-
rcost.addCost(settings_.ee_name + "_cost", QuadraticResidualCost(space, frame_residual, settings_.w_frame));
46+
rcost.addCost("frame_cost", QuadraticResidualCost(space, frame_residual, settings_.w_frame));
4747
}
4848
else
4949
{
50-
rcost.addCost(settings_.ee_name + "_cost", QuadraticResidualCost(space, frame_residual, settings_.w_frame), 0.);
50+
rcost.addCost("frame_cost", QuadraticResidualCost(space, frame_residual, settings_.w_frame), 0.);
5151
}
5252

5353
MultibodyFreeFwdDynamics ode = MultibodyFreeFwdDynamics(space, Eigen::MatrixXd::Identity(nv_, nv_));
@@ -63,8 +63,14 @@ namespace simple_mpc
6363
}
6464
if (settings_.kinematics_limits)
6565
{
66+
std::vector<int> state_id;
67+
for (int i = 0; i < nv_; i++)
68+
{
69+
state_id.push_back(i);
70+
}
6671
StateErrorResidual state_fn = StateErrorResidual(space, nv_, space.neutral());
67-
stm.addConstraint(state_fn, BoxConstraint(settings_.qmin, settings_.qmax));
72+
FunctionSliceXpr state_slice = FunctionSliceXpr(state_fn, state_id);
73+
stm.addConstraint(state_slice, BoxConstraint(settings_.qmin, settings_.qmax));
6874
}
6975

7076
return stm;
@@ -81,18 +87,43 @@ namespace simple_mpc
8187
return cs;
8288
}
8389

90+
CostStack * ArmDynamicsOCP::getTerminalCostStack()
91+
{
92+
CostStack * cs = dynamic_cast<CostStack *>(&*problem_->term_cost_);
93+
94+
return cs;
95+
}
96+
8497
void ArmDynamicsOCP::setReferencePose(const std::size_t t, const Eigen::Vector3d & pose_ref)
8598
{
8699
CostStack * cs = getCostStack(t);
87-
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>(settings_.ee_name + "_cost");
100+
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>("frame_cost");
88101
FrameTranslationResidual * cfr = qrc->getResidual<FrameTranslationResidual>();
89102
cfr->setReference(pose_ref);
90103
}
91104

92105
const Eigen::Vector3d ArmDynamicsOCP::getReferencePose(const std::size_t t)
93106
{
94107
CostStack * cs = getCostStack(t);
95-
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>(settings_.ee_name + "_cost");
108+
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>("frame_cost");
109+
FrameTranslationResidual * cfr = qrc->getResidual<FrameTranslationResidual>();
110+
Eigen::Vector3d ref = cfr->getReference();
111+
112+
return ref;
113+
}
114+
115+
void ArmDynamicsOCP::setTerminalReferencePose(const Eigen::Vector3d & pose_ref)
116+
{
117+
CostStack * cs = getTerminalCostStack();
118+
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>("frame_cost");
119+
FrameTranslationResidual * cfr = qrc->getResidual<FrameTranslationResidual>();
120+
cfr->setReference(pose_ref);
121+
}
122+
123+
const Eigen::Vector3d ArmDynamicsOCP::getTerminalReferencePose()
124+
{
125+
CostStack * cs = getTerminalCostStack();
126+
QuadraticResidualCost * qrc = cs->getComponent<QuadraticResidualCost>("frame_cost");
96127
FrameTranslationResidual * cfr = qrc->getResidual<FrameTranslationResidual>();
97128
Eigen::Vector3d ref = cfr->getReference();
98129

@@ -119,6 +150,30 @@ namespace simple_mpc
119150
return qc->getTarget();
120151
}
121152

153+
void ArmDynamicsOCP::setTerminalWeight(const std::string key, double weight)
154+
{
155+
CostStack * cs = getTerminalCostStack();
156+
cs->setWeight(key, weight);
157+
}
158+
159+
double ArmDynamicsOCP::getTerminalWeight(const std::string key)
160+
{
161+
CostStack * cs = getTerminalCostStack();
162+
return cs->getWeight(key);
163+
}
164+
165+
void ArmDynamicsOCP::setWeight(const std::size_t t, const std::string key, double weight)
166+
{
167+
CostStack * cs = getCostStack(t);
168+
cs->setWeight(key, weight);
169+
}
170+
171+
double ArmDynamicsOCP::getWeight(const std::size_t t, const std::string key)
172+
{
173+
CostStack * cs = getCostStack(t);
174+
return cs->getWeight(key);
175+
}
176+
122177
CostStack ArmDynamicsOCP::createTerminalCost()
123178
{
124179
auto ter_space = MultibodyPhaseSpace(model_handler_.getModel());
@@ -127,6 +182,10 @@ namespace simple_mpc
127182
term_cost.addCost(
128183
"state_cost", QuadraticStateCost(ter_space, nv_, model_handler_.getReferenceState(), settings_.w_x));
129184

185+
FrameTranslationResidual frame_residual =
186+
FrameTranslationResidual(ter_space.ndx(), nv_, model_handler_.getModel(), Eigen::Vector3d::Zero(), ee_id_);
187+
term_cost.addCost("frame_cost", QuadraticResidualCost(ter_space, frame_residual, settings_.w_frame), 0.0);
188+
130189
return term_cost;
131190
}
132191

src/arm-mpc.cpp

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,42 +36,26 @@ namespace simple_mpc
3636
else
3737
solver_->linear_solver_choice = aligator::LQSolverChoice::SERIAL;
3838
solver_->force_initial_condition_ = true;
39-
// solver_->reg_min = 1e-6;
39+
solver_->reg_min = 1e-6;
4040

4141
ee_name_ = problem->getSettings().ee_name;
4242

4343
for (std::size_t i = 0; i < ocp_handler_->getProblem().numSteps(); i++)
4444
{
4545
xs_.push_back(x0_);
4646
us_.push_back(Eigen::VectorXd::Zero(model_handler.getModel().nv));
47-
48-
std::shared_ptr<StageModel> sm = std::make_shared<StageModel>(ocp_handler_->createStage());
49-
rest_horizon_.push_back(sm);
50-
rest_horizon_data_.push_back(sm->createData());
5147
}
5248
xs_.push_back(x0_);
5349

5450
solver_->setup(ocp_handler_->getProblem());
5551
solver_->run(ocp_handler_->getProblem(), xs_, us_);
5652

57-
/*xs_ = solver_->results_.xs;
53+
xs_ = solver_->results_.xs;
5854
us_ = solver_->results_.us;
5955
Ks_ = solver_->results_.getCtrlFeedbacks();
6056

61-
solver_->max_iters = settings_.max_iters; */
62-
}
63-
64-
void ArmMPC::generateReachHorizon(const Eigen::Vector3d & reach_pose)
65-
{
66-
reach_pose_ = reach_pose;
67-
// Generate the model stages for cycle horizon
68-
for (std::size_t i = 0; i < ocp_handler_->getProblem().numSteps(); i++)
69-
{
70-
71-
std::shared_ptr<StageModel> sm = std::make_shared<StageModel>(ocp_handler_->createStage(true, reach_pose));
72-
reach_horizon_.push_back(sm);
73-
reach_horizon_data_.push_back(sm->createData());
74-
}
57+
solver_->max_iters = settings_.max_iters;
58+
now_ = RESTING;
7559
}
7660

7761
void ArmMPC::iterate(const ConstVectorRef & x)
@@ -107,33 +91,46 @@ namespace simple_mpc
10791

10892
void ArmMPC::recedeWithCycle()
10993
{
94+
std::size_t last_id = ocp_handler_->getSize() - 1;
95+
rotate_vec_left(ocp_handler_->getProblem().stages_);
96+
11097
if (now_ == REACHING)
11198
{
112-
ocp_handler_->getProblem().replaceStageCircular(*reach_horizon_[0]);
113-
solver_->cycleProblem(ocp_handler_->getProblem(), reach_horizon_data_[0]);
114-
115-
rotate_vec_left(reach_horizon_);
116-
rotate_vec_left(reach_horizon_data_);
99+
ocp_handler_->setWeight(last_id, "frame_cost", 1.0);
100+
ocp_handler_->setTerminalWeight("frame_cost", 1.0);
117101
}
118102
else
119103
{
120-
ocp_handler_->getProblem().replaceStageCircular(*rest_horizon_[0]);
121-
solver_->cycleProblem(ocp_handler_->getProblem(), rest_horizon_data_[0]);
122-
123-
rotate_vec_left(rest_horizon_);
124-
rotate_vec_left(rest_horizon_data_);
104+
ocp_handler_->setWeight(last_id, "frame_cost", 0.0);
105+
ocp_handler_->setTerminalWeight("frame_cost", 0.0);
125106
}
126107
}
127108

128109
void ArmMPC::updateTargetReference()
129110
{
130111
ocp_handler_->setReferencePose(ocp_handler_->getSize() - 1, reach_pose_);
112+
ocp_handler_->setTerminalReferencePose(reach_pose_);
131113
ocp_handler_->setReferenceState(ocp_handler_->getSize() - 1, x_reference_);
132114
}
133115

116+
void ArmMPC::setReferencePose(const std::size_t t, const Eigen::Vector3d & pose_ref)
117+
{
118+
if (t < ocp_handler_->getSize() - 1)
119+
ocp_handler_->setReferencePose(t, pose_ref);
120+
else
121+
ocp_handler_->setTerminalReferencePose(pose_ref);
122+
;
123+
}
124+
134125
const Eigen::Vector3d ArmMPC::getReferencePose(const std::size_t t) const
135126
{
136-
return ocp_handler_->getReferencePose(t);
127+
Eigen::Vector3d pos_ref;
128+
if (t < ocp_handler_->getSize() - 1)
129+
pos_ref = ocp_handler_->getReferencePose(t);
130+
else
131+
pos_ref = ocp_handler_->getTerminalReferencePose();
132+
133+
return pos_ref;
137134
}
138135

139136
TrajOptProblem & ArmMPC::getTrajOptProblem()

tests/mpc.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,6 @@ BOOST_AUTO_TEST_CASE(mpc_armdynamics)
272272
BOOST_CHECK_EQUAL(mpc.xs_.size(), T + 1);
273273
BOOST_CHECK_EQUAL(mpc.us_.size(), T);
274274

275-
Eigen::Vector3d reach_pose;
276-
reach_pose << 0.5, 0.5, 0.5;
277-
mpc.generateReachHorizon(reach_pose);
278-
279275
for (std::size_t i = 0; i < 10; i++)
280276
{
281277
mpc.iterate(model_handler.getReferenceState());

0 commit comments

Comments
 (0)