diff --git a/core/include/moveit/task_constructor/stage.h b/core/include/moveit/task_constructor/stage.h index 82d01c765..25148dba9 100644 --- a/core/include/moveit/task_constructor/stage.h +++ b/core/include/moveit/task_constructor/stage.h @@ -211,6 +211,11 @@ class Stage return properties().get("trajectory_execution_info"); } + /// If true, a failure of this stage will cause immediate preemption of the task + void setPreemptOnFailure(bool preempt_on_failure) { setProperty("preempt_on_failure", preempt_on_failure); } + /// Get whether this stage is configured to preempt the task on failure + bool preemptOnFailure() const { return properties().get("preempt_on_failure"); } + /// forwarding of properties between interface states void forwardProperties(const InterfaceState& source, InterfaceState& dest); std::set forwardedProperties() const { @@ -257,6 +262,9 @@ class Stage double getTotalComputeTime() const; + /// Request preemption of task execution + void requestPreemption(); + protected: /// Stage can only be instantiated through derived classes Stage(StagePrivate* impl); diff --git a/core/src/container.cpp b/core/src/container.cpp index c1fc23f11..277189be8 100644 --- a/core/src/container.cpp +++ b/core/src/container.cpp @@ -83,6 +83,16 @@ printChildrenInterfaces(const ContainerBasePrivate& container, bool success, con } } +// Check if preemptOnFailure is set for the given stage +static void checkPreemptOnFailure(const Stage& child) { + if (child.preemptOnFailure()) { + RCLCPP_WARN_STREAM( + rclcpp::get_logger("Container"), + fmt::format("'{}' failed and has preempt_on_failure set to true. Preempting task planning.", child.name())); + const_cast(child).requestPreemption(); + } +} + ContainerBasePrivate::ContainerBasePrivate(ContainerBase* me, const std::string& name) : StagePrivate(me, name) , required_interface_(UNKNOWN) @@ -232,6 +242,8 @@ inline void updateStatePrios(const InterfaceState& s, const InterfaceState::Prio } void ContainerBasePrivate::onNewFailure(const Stage& child, const InterfaceState* from, const InterfaceState* to) { + checkPreemptOnFailure(child); + if (!static_cast(me_)->pruning()) return; @@ -924,6 +936,8 @@ void FallbacksPrivate::onNewSolution(const SolutionBase& s) { } void FallbacksPrivate::onNewFailure(const Stage& child, const InterfaceState* /*from*/, const InterfaceState* /*to*/) { + checkPreemptOnFailure(child); + // This override is deliberately empty. // The method prunes solution paths when a child failed to find a valid solution for it, // but in Fallbacks the next child might still yield a successful solution @@ -1074,6 +1088,8 @@ void FallbacksPrivateConnect::compute() { } void FallbacksPrivateConnect::onNewFailure(const Stage& child, const InterfaceState* from, const InterfaceState* to) { + checkPreemptOnFailure(child); + // expect failure to be reported from active child assert(active_ != children().end() && active_->get() == &child); (void)child; diff --git a/core/src/stage.cpp b/core/src/stage.cpp index 789da4651..e73a2cb23 100644 --- a/core/src/stage.cpp +++ b/core/src/stage.cpp @@ -313,7 +313,8 @@ Stage::Stage(StagePrivate* impl) : pimpl_(impl) { p.declare("marker_ns", name(), "marker namespace"); p.declare("trajectory_execution_info", TrajectoryExecutionInfo(), "settings used when executing the trajectory"); - + p.declare("preempt_on_failure", false, + "if true, a failure of this stage will cause immediate preemption of the task"); p.declare>("forwarded_properties", std::set(), "set of interface properties to forward"); } @@ -448,6 +449,12 @@ double Stage::getTotalComputeTime() const { return pimpl()->total_compute_time_.count(); } +void Stage::requestPreemption() { + if (pimpl()->preempt_requested_ != nullptr) { + const_cast*>(pimpl()->preempt_requested_)->store(true); + } +} + void StagePrivate::composePropertyErrorMsg(const std::string& property_name, std::ostream& os) { if (property_name.empty()) return;