diff --git a/bindings/expose-mpc.cpp b/bindings/expose-mpc.cpp index 65bcba54..131409a1 100644 --- a/bindings/expose-mpc.cpp +++ b/bindings/expose-mpc.cpp @@ -88,6 +88,7 @@ namespace simple_mpc .def("getFootTakeoffCycle", &MPC::getFootTakeoffCycle, ("self"_a, "ee_name")) .def("getFootLandCycle", &MPC::getFootLandCycle, ("self"_a, "ee_name")) .def("getStateDerivative", &MPC::getStateDerivative, ("self"_a, "t")) + .def("getContactForces", &MPC::getContactForces, ("self"_a, "t")) .def("getCyclingContactState", &MPC::getCyclingContactState, ("self"_a, "t", "ee_name")) .def( "getModelHandler", &MPC::getModelHandler, "self"_a, bp::return_internal_reference<>(), diff --git a/examples/go2_fulldynamics.py b/examples/go2_fulldynamics.py index c754cea6..2ce70547 100644 --- a/examples/go2_fulldynamics.py +++ b/examples/go2_fulldynamics.py @@ -201,7 +201,7 @@ L_measured = [] v = np.zeros(6) -v[0] = 0. +v[0] = 0.2 mpc.velocity_base = v for t in range(500): print("Time " + str(t)) @@ -214,7 +214,6 @@ str(land_RF) + ", takeoff_LF = " + str(takeoff_LF) + ", landing_LF = ", str(land_LF), ) """ - """ if t == 200: for s in range(T): device.resetState(mpc.xs[s][:nq]) @@ -238,15 +237,16 @@ a0 = mpc.getStateDerivative(0)[nv:] a1 = mpc.getStateDerivative(1)[nv:] - FL_f, FR_f, RL_f, RR_f = extract_forces(mpc.getTrajOptProblem(), mpc.solver.workspace, 0) + forces_vec0 = mpc.getContactForces(0) + forces_vec1 = mpc.getContactForces(1) contact_states = mpc.ocp_handler.getContactState(0) - total_forces = np.concatenate((FL_f, FR_f, RL_f, RR_f)) - force_FL.append(FL_f) - force_FR.append(FR_f) - force_RL.append(RL_f) - force_RR.append(RR_f) - forces = [total_forces, total_forces] + force_FL.append(forces_vec0[:3]) + force_FR.append(forces_vec0[3:6]) + force_RL.append(forces_vec0[6:9]) + force_RR.append(forces_vec0[9:12]) + + forces = [forces_vec0, forces_vec1] ddqs = [a0, a1] xss = [mpc.xs[0], mpc.xs[1]] uss = [mpc.us[0], mpc.us[1]] diff --git a/include/simple-mpc/mpc.hpp b/include/simple-mpc/mpc.hpp index 03a5e6de..a3e09a77 100644 --- a/include/simple-mpc/mpc.hpp +++ b/include/simple-mpc/mpc.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "simple-mpc/deprecated.hpp" @@ -22,6 +23,8 @@ namespace simple_mpc { using ExplicitIntegratorData = dynamics::ExplicitIntegratorDataTpl; + using MultibodyConstraintFwdData = dynamics::MultibodyConstraintFwdDataTpl; + using MultibodyConstraintFwdDynamics = dynamics::MultibodyConstraintFwdDynamicsTpl; struct MPCSettings { @@ -175,13 +178,13 @@ namespace simple_mpc } } - const ConstVectorRef getStateDerivative(const std::size_t t) - { - ExplicitIntegratorData * int_data = - dynamic_cast(&*solver_->workspace_.problem_data.stage_data[t]->dynamics_data); - assert(int_data != nullptr); - return int_data->continuous_data->xdot_; - } + const ConstVectorRef getStateDerivative(const std::size_t t); + + /** + * @brief Return contact forces for a full dynamics MPC problem + * @warning Only work with fulldynamics OCP handler + */ + const Eigen::VectorXd getContactForces(const std::size_t t); void switchToWalk(const Vector6d & velocity_base); diff --git a/src/mpc.cpp b/src/mpc.cpp index 74d36af0..917fcd12 100644 --- a/src/mpc.cpp +++ b/src/mpc.cpp @@ -342,6 +342,43 @@ namespace simple_mpc return ocp_handler_->getProblem(); } + const ConstVectorRef MPC::getStateDerivative(const std::size_t t) + { + ExplicitIntegratorData * int_data = + dynamic_cast(&*solver_->workspace_.problem_data.stage_data[t]->dynamics_data); + assert(int_data != nullptr); + return int_data->continuous_data->xdot_; + } + + const Eigen::VectorXd MPC::getContactForces(const std::size_t t) + { + Eigen::VectorXd contact_forces; + contact_forces.resize(3 * (long)ee_names_.size()); + + ExplicitIntegratorData * int_data = + dynamic_cast(&*solver_->workspace_.problem_data.stage_data[t]->dynamics_data); + assert(int_data != nullptr); + MultibodyConstraintFwdData * mc_data = dynamic_cast(&*int_data->continuous_data); + assert(mc_data != nullptr); + + std::vector contact_state = ocp_handler_->getContactState(t); + + size_t force_id = 0; + for (size_t i = 0; i < contact_state.size(); i++) + { + if (contact_state[i]) + { + contact_forces.segment((long)i * 3, 3) = mc_data->constraint_datas_[force_id].contact_force.linear(); + force_id += 1; + } + else + { + contact_forces.segment((long)i * 3, 3).setZero(); + } + } + return contact_forces; + } + void MPC::switchToWalk(const Vector6d & velocity_base) { now_ = WALKING; diff --git a/tests/mpc.cpp b/tests/mpc.cpp index 4143fc1e..e021c009 100644 --- a/tests/mpc.cpp +++ b/tests/mpc.cpp @@ -90,6 +90,8 @@ BOOST_AUTO_TEST_CASE(mpc_fulldynamics) BOOST_CHECK_EQUAL(mpc.foot_land_times_.at("right_sole_link")[0], 150); Eigen::VectorXd xdot = mpc.getStateDerivative(0); + + Eigen::VectorXd forces = mpc.getContactForces(0); } BOOST_AUTO_TEST_CASE(mpc_kinodynamics)