From abead7403332f5644c2d1b0b8cf1c94719473d2a Mon Sep 17 00:00:00 2001 From: taylor howell Date: Wed, 20 Mar 2024 09:22:49 -0400 Subject: [PATCH 1/2] separate planner and estimator model loading --- mjpc/agent.cc | 14 ++++- mjpc/agent.h | 8 ++- mjpc/app.cc | 39 ++++++++++++- mjpc/task.h | 2 + mjpc/tasks/CMakeLists.txt | 3 + mjpc/tasks/particle/particle.cc | 5 +- mjpc/tasks/particle/particle.h | 1 + mjpc/tasks/particle/particle.xml | 49 ++++++++++++++++ mjpc/tasks/particle/particle_red.xml.patch | 57 +++++++++++++++++++ ...evarying.xml => task_timevarying_plan.xml} | 0 mjpc/tasks/particle/task_timevarying_task.xml | 46 +++++++++++++++ 11 files changed, 217 insertions(+), 7 deletions(-) create mode 100644 mjpc/tasks/particle/particle.xml create mode 100644 mjpc/tasks/particle/particle_red.xml.patch rename mjpc/tasks/particle/{task_timevarying.xml => task_timevarying_plan.xml} (100%) create mode 100644 mjpc/tasks/particle/task_timevarying_task.xml diff --git a/mjpc/agent.cc b/mjpc/agent.cc index 65818a15a..2d9d5c16f 100644 --- a/mjpc/agent.cc +++ b/mjpc/agent.cc @@ -68,11 +68,21 @@ Agent::Agent(const mjModel* model, std::shared_ptr task) } // initialize data, settings, planners, state -void Agent::Initialize(const mjModel* model) { +void Agent::Initialize(const mjModel* model, const mjModel* estimator_model) { // ----- model ----- // + + // planner model if (model_) mj_deleteModel(model_); model_ = mj_copyModel(nullptr, model); // agent's copy of model + // estimator model + if (model_estimator_) mj_deleteModel(model_estimator_); + if (estimator_model) { + model_estimator_ = mj_copyModel(nullptr, estimator_model); + } else { + model_estimator_ = mj_copyModel(nullptr, model); + } + // check for limits on all actuators int num_missing = 0; for (int i = 0; i < model_->nu; i++) { @@ -120,7 +130,7 @@ void Agent::Initialize(const mjModel* model) { // initialize estimator if (reset_estimator && estimator_enabled) { for (const auto& estimator : estimators_) { - estimator->Initialize(model_); + estimator->Initialize(model_estimator_); estimator->Reset(); } } diff --git a/mjpc/agent.h b/mjpc/agent.h index bce4cfbc9..10403e77a 100644 --- a/mjpc/agent.h +++ b/mjpc/agent.h @@ -54,12 +54,15 @@ class Agent { // destructor ~Agent() { if (model_) mj_deleteModel(model_); // we made a copy in Initialize + if (model_estimator_) + mj_deleteModel(model_estimator_); // we made a copy in Initialize } // ----- methods ----- // // initialize data, settings, planners, states - void Initialize(const mjModel* model); + void Initialize(const mjModel* model, + const mjModel* estimator_model = nullptr); // allocate memory void Allocate(); @@ -120,6 +123,8 @@ class Agent { std::string GetTaskNames() const { return task_names_; } int GetTaskIdByName(std::string_view name) const; std::string GetTaskXmlPath(int id) const { return tasks_[id]->XmlPath(); } + std::string GetPlannerXmlPath(int id) const { return tasks_[id]->PlannerXmlPath(); } + std::string GetEstimatorXmlPath(int id) const { return tasks_[id]->EstimatorXmlPath(); } // load the latest task model, based on GUI settings struct LoadModelResult { @@ -190,6 +195,7 @@ class Agent { private: // model mjModel* model_ = nullptr; + mjModel* model_estimator_ = nullptr; UniqueMjModel model_override_ = {nullptr, mj_deleteModel}; diff --git a/mjpc/app.cc b/mjpc/app.cc index a71590803..d28fec90e 100644 --- a/mjpc/app.cc +++ b/mjpc/app.cc @@ -222,6 +222,23 @@ void PhysicsLoop(mj::Simulate& sim) { // ----- task reload ----- // if (sim.uiloadrequest.load() == 1) { + // get new estimator model + sim.filename = sim.agent->GetEstimatorXmlPath(sim.agent->gui_task_id); + + mjModel* mest = nullptr; + if (!sim.filename.empty()) { + mest = LoadModel(sim.agent.get(), sim); + } + + // get new planner model + sim.filename = sim.agent->GetPlannerXmlPath(sim.agent->gui_task_id); + + mjModel* mplan = nullptr; + if (!sim.filename.empty()) { + mplan = LoadModel(sim.agent.get(), sim); + sim.agent->Initialize(mplan, mest); + } + // get new model + task sim.filename = sim.agent->GetTaskXmlPath(sim.agent->gui_task_id); @@ -229,7 +246,7 @@ void PhysicsLoop(mj::Simulate& sim) { mjData* dnew = nullptr; if (mnew) dnew = mj_makeData(mnew); if (dnew) { - sim.agent->Initialize(mnew); + if (mplan == nullptr) sim.agent->Initialize(mnew, mest); sim.agent->plot_enabled = absl::GetFlag(FLAGS_show_plot); sim.agent->plan_enabled = absl::GetFlag(FLAGS_planner_enabled); sim.agent->Allocate(); @@ -428,6 +445,23 @@ MjpcApp::MjpcApp(std::vector> tasks) { } } + // estimator setup + sim->agent->estimator_enabled = absl::GetFlag(FLAGS_estimator_enabled); + mjModel* mest = nullptr; + sim->filename = sim->agent->GetEstimatorXmlPath(sim->agent->gui_task_id); + if (!sim->filename.empty()) { + mest = LoadModel(sim->agent.get(), *sim); + } + + // load planner model + mjModel* mplan = nullptr; + sim->filename = sim->agent->GetPlannerXmlPath(sim->agent->gui_task_id); + if (!sim->filename.empty()) { + mplan = LoadModel(sim->agent.get(), *sim); + sim->agent->Initialize(mplan, mest); + } + + // load task model sim->filename = sim->agent->GetTaskXmlPath(sim->agent->gui_task_id); m = LoadModel(sim->agent.get(), *sim); if (m) d = mj_makeData(m); @@ -445,8 +479,7 @@ MjpcApp::MjpcApp(std::vector> tasks) { mju_zero(ctrlnoise, m->nu); // agent - sim->agent->estimator_enabled = absl::GetFlag(FLAGS_estimator_enabled); - sim->agent->Initialize(m); + if (mplan == nullptr) sim->agent->Initialize(m, mest); sim->agent->Allocate(); sim->agent->Reset(); sim->agent->PlotInitialize(); diff --git a/mjpc/task.h b/mjpc/task.h index 5f8d0cdc5..62b1b1019 100644 --- a/mjpc/task.h +++ b/mjpc/task.h @@ -123,6 +123,8 @@ class Task { virtual std::string Name() const = 0; virtual std::string XmlPath() const = 0; + virtual std::string PlannerXmlPath() const { return ""; }; + virtual std::string EstimatorXmlPath() const { return ""; }; // mode int mode; diff --git a/mjpc/tasks/CMakeLists.txt b/mjpc/tasks/CMakeLists.txt index 3cbfa14a7..5ee262c86 100644 --- a/mjpc/tasks/CMakeLists.txt +++ b/mjpc/tasks/CMakeLists.txt @@ -46,6 +46,9 @@ add_custom_target( COMMAND patch -o ${CMAKE_CURRENT_BINARY_DIR}/particle/particle_modified.xml ${CMAKE_CURRENT_BINARY_DIR}/particle/particle.xml <${CMAKE_CURRENT_SOURCE_DIR}/particle/particle.xml.patch + COMMAND patch -o ${CMAKE_CURRENT_BINARY_DIR}/particle/particle_modified_red.xml + ${CMAKE_CURRENT_BINARY_DIR}/particle/particle.xml + <${CMAKE_CURRENT_SOURCE_DIR}/particle/particle_red.xml.patch # swimmer COMMAND ${CMAKE_COMMAND} -E copy ${dm_control_SOURCE_DIR}/dm_control/suite/swimmer.xml diff --git a/mjpc/tasks/particle/particle.cc b/mjpc/tasks/particle/particle.cc index 58c157b13..ff9d657c0 100644 --- a/mjpc/tasks/particle/particle.cc +++ b/mjpc/tasks/particle/particle.cc @@ -22,7 +22,10 @@ namespace mjpc { std::string Particle::XmlPath() const { - return GetModelPath("particle/task_timevarying.xml"); + return GetModelPath("particle/task_timevarying_task.xml"); +} +std::string Particle::PlannerXmlPath() const { + return GetModelPath("particle/task_timevarying_plan.xml"); } std::string Particle::Name() const { return "Particle"; } diff --git a/mjpc/tasks/particle/particle.h b/mjpc/tasks/particle/particle.h index e81a1995c..284828f61 100644 --- a/mjpc/tasks/particle/particle.h +++ b/mjpc/tasks/particle/particle.h @@ -25,6 +25,7 @@ class Particle : public Task { public: std::string Name() const override; std::string XmlPath() const override; + std::string PlannerXmlPath() const override; class ResidualFn : public mjpc::BaseResidualFn { public: explicit ResidualFn(const Particle* task) : mjpc::BaseResidualFn(task) {} diff --git a/mjpc/tasks/particle/particle.xml b/mjpc/tasks/particle/particle.xml new file mode 100644 index 000000000..c447cf614 --- /dev/null +++ b/mjpc/tasks/particle/particle.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mjpc/tasks/particle/particle_red.xml.patch b/mjpc/tasks/particle/particle_red.xml.patch new file mode 100644 index 000000000..cdc77c8e4 --- /dev/null +++ b/mjpc/tasks/particle/particle_red.xml.patch @@ -0,0 +1,57 @@ +diff --git a/particle_modified.xml b/particle_modified.xml +--- a/particle_modified.xml ++++ b/particle_modified.xml +@@ -1,9 +1,5 @@ +- +- +- +- +- +-