Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 63 additions & 6 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
MissionRuntimeOptions,
MissionStepSetData,
MissionStepPoseWaypoint,
MissionStepIf,
Pose,
)
from inorbit_edge_executor.worker_pool import WorkerPool
Expand Down Expand Up @@ -121,8 +122,11 @@ async def main():
"state": "starting",
"label": "Example inorbit_edge_executor mission",
"tasks": [
{"taskId": "step 0", "label": "Step 0"},
{"taskId": "step 1", "label": "Step 1"},
{"taskId": "step 2", "label": "Step 2"},
{"taskId": "step 3", "label": "Then waypoint A"},
{"taskId": "step 4", "label": "Else waypoint B"},
],
},
)
Expand All @@ -135,13 +139,66 @@ async def main():
definition=MissionDefinition(
label="A mission definition",
steps=[
MissionStepSetData(
label="set some data", completeTask="step 1", data={"key": "value"}
MissionStepSetData.model_validate(
{
"label": "set some data",
"completeTask": "step 0",
"data": {"key": "value"},
}
),
MissionStepPoseWaypoint(
label="go to waypoint",
completeTask="step 2",
waypoint=Pose(x=0, y=0, theta=0, frameId="map", waypointId="wp1"),
MissionStepSetData.model_validate(
{
"label": "set more data",
"completeTask": "step 1",
"data": {"key2": "value2"},
}
),
MissionStepPoseWaypoint.model_validate(
{
"label": "go to waypoint",
"completeTask": "step 2",
"waypoint": {
"x": 0,
"y": 0,
"theta": 0,
"frameId": "map",
"waypointId": "wp1",
},
}
),
MissionStepIf.model_validate(
{
"label": "if",
"if": {
"expression": "0 > 1",
"then": [
{
"label": "go to waypoint A",
"completeTask": "step 3",
"waypoint": {
"x": 0,
"y": 0,
"theta": 0,
"frameId": "map",
"waypointId": "wpA",
},
},
],
"else": [
{
"label": "go to waypoint B",
"completeTask": "step 4",
"waypoint": {
"x": 0,
"y": 0,
"theta": 0,
"frameId": "map",
"waypointId": "wpB",
},
},
],
},
}
),
],
),
Expand Down
238 changes: 217 additions & 21 deletions inorbit_edge_executor/behavior_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
from typing import Dict
from typing import List
from typing import Union
from typing import Callable

from async_timeout import timeout

from .datatypes import MissionRuntimeOptions
from .datatypes import MissionRuntimeSharedMemory
from .datatypes import MissionStep
from .datatypes import MissionStepPoseWaypoint
from .datatypes import MissionStepRunAction
from .datatypes import MissionStepSetData
Expand Down Expand Up @@ -686,6 +688,119 @@ def from_object(cls, context, expression, target=None, **kwargs):
return WaitExpressionNode(context, expression, target, **kwargs)


class IfNode(BehaviorTree):
"""
Node that evaluates an expression once and conditionally executes either a "then" or "else"
branch based on the result. The expression is evaluated through REST APIs, normally in the same
robot that executes the mission.
"""

def __init__(
self,
context: BehaviorTreeBuilderContext,
expression: str,
then_branch: BehaviorTree,
else_branch: BehaviorTree = None,
target: Target = None,
retry_wait_secs: float = 3,
**kwargs,
):
super().__init__(**kwargs)
self.expression = expression
self.retry_wait_secs = retry_wait_secs
self.then_branch = then_branch
self.else_branch = else_branch
self.target = target
if self.target is None:
self.robot = context.robot_api
else:
self.robot = context.robot_api_factory.build(self.target.robot_id)

async def _execute(self):
logger.debug(f"evaluating expression {self.expression} on {self.robot.robot_id}")
try:
result = None
max_attempts = 5
for attempt in range(1, max_attempts + 1):
try:
logger.debug(
f"Attempt {attempt}/{max_attempts} to evaluate expression: {self.expression}"
)
result = await self.robot.evaluate_expression(self.expression)
break
except Exception as e:
logger.warning(
f"Attempt {attempt} failed for expression {self.expression}: {e}"
)
if attempt < max_attempts:
await asyncio.sleep(self.retry_wait_secs)
else:
logger.error(
f"All {max_attempts} attempts failed for expression {self.expression}"
)
raise
except Exception as e:
logger.error(f"Error evaluating expression {self.expression}: {e}")
raise e

if result:
logger.debug(f"expression {self.expression} == true, executing then branch")
await self.then_branch.execute()
self.state = self.then_branch.state
self.last_error = self.then_branch.last_error
else:
if self.else_branch is not None:
logger.debug(f"expression {self.expression} == false, executing else branch")
await self.else_branch.execute()
self.state = self.else_branch.state
self.last_error = self.else_branch.last_error
else:
logger.debug(f"expression {self.expression} == false, no else branch, succeeding")
# No else branch, succeed (no-op)
self.state = NODE_STATE_SUCCESS
self.last_error = ""

def reset_execution(self):
super().reset_execution()
if self.then_branch:
self.then_branch.reset_execution()
if self.else_branch:
self.else_branch.reset_execution()

def reset_handlers_execution(self):
super().reset_handlers_execution()
if self.then_branch:
self.then_branch.reset_handlers_execution()
if self.else_branch:
self.else_branch.reset_handlers_execution()

def collect_nodes(self, nodes_list: List):
super().collect_nodes(nodes_list)
if self.then_branch:
self.then_branch.collect_nodes(nodes_list)
if self.else_branch:
self.else_branch.collect_nodes(nodes_list)

def dump_object(self):
object = super().dump_object()
object["expression"] = self.expression
object["then_branch"] = self.then_branch.dump_object()
if self.else_branch is not None:
object["else_branch"] = self.else_branch.dump_object()
if self.target is not None:
object["target"] = self.target.dump_object()
object["retry_wait_secs"] = self.retry_wait_secs
return object

@classmethod
def from_object(cls, context, expression, then_branch, else_branch=None, target=None, **kwargs):
then_branch_tree = build_tree_from_object(context, then_branch)
else_branch_tree = build_tree_from_object(context, else_branch) if else_branch else None
if target is not None:
target = Target.from_object(**target)
return IfNode(context, expression, then_branch_tree, else_branch_tree, target, **kwargs)


class DummyNode(BehaviorTree):
async def _execute(self):
pass
Expand Down Expand Up @@ -983,6 +1098,11 @@ def from_object(cls, context, node_state, **kwargs):

class NodeFromStepBuilder:
def __init__(self, context: BehaviorTreeBuilderContext):
"""
Implements the visitor pattern for building behavior tree nodes from mission steps.
Args:
context: The behavior tree builder context.
"""
self.context = context
self.waypoint_distance_tolerance = WAYPOINT_DISTANCE_TOLERANCE_DEFAULT
self.waypoint_angular_tolerance = WAYPOINT_ANGULAR_TOLERANCE_DEFAULT
Expand All @@ -1001,6 +1121,27 @@ def __init__(self, context: BehaviorTreeBuilderContext):
if WAYPOINT_ANGULAR_TOLERANCE in args:
self.waypoint_angular_tolerance = float(args[WAYPOINT_ANGULAR_TOLERANCE])

def add_step_node_decorator(
self, step_decorator_fn: Callable[[MissionStep, BehaviorTree], BehaviorTree]
):
# Patch all visit_* methods so that they call the step decorator around the real core node
for attr_name in dir(self):
if attr_name.startswith("visit_") and callable(getattr(self, attr_name)):
orig_method = getattr(self, attr_name)
# Don't double-wrap if it's already wrapped (avoid recursion)
if hasattr(orig_method, "__wrapped_with_step_wrapper__"):
continue

def make_wrapped(orig_method):
def visit_method(step):
core_node = orig_method(step)
return step_decorator_fn(step, core_node)

visit_method.__wrapped_with_step_wrapper__ = True
return visit_method

setattr(self, attr_name, make_wrapped(orig_method))

def visit_wait(self, step: MissionStepWait):
return WaitNode(self.context, step.timeout_secs, label=step.label)

Expand Down Expand Up @@ -1117,7 +1258,32 @@ def visit_wait_until(self, step: MissionStepWaitUntil):
)

def visit_if(self, step: MissionStepIf):
raise NotImplementedError("visit_if not implemented")
# Build the behavior tree nodes for the then branch
then_label = f"{step.label} - then" if step.label else "then"
then_branch = BehaviorTreeSequential(label=then_label)
for then_step in step.then:
node = then_step.accept(self)
if node:
then_branch.add_node(node)
# Build the behavior tree nodes for the else branch (if it exists)
else_branch = None
if step.else_ is not None:
else_label = f"{step.label} - else" if step.label else "else"
else_branch = BehaviorTreeSequential(label=else_label)
for else_step in step.else_:
node = else_step.accept(self)
if node:
else_branch.add_node(node)
# Create the if node
if_node = IfNode(
context=self.context,
expression=step.expression,
then_branch=then_branch,
else_branch=else_branch,
target=step.target,
label=step.label,
)
return if_node


# List of accepted node types (classes). With register_accepted_node_types(),
Expand All @@ -1128,6 +1294,7 @@ def visit_if(self, step: MissionStepIf):
WaitNode,
RunActionNode,
WaitExpressionNode,
IfNode,
DummyNode,
TimeoutNode,
MissionStartNode,
Expand Down Expand Up @@ -1181,41 +1348,70 @@ def __init__(self, step_builder_factory: NodeFromStepBuilder = None, *args, **kw
step_builder_factory if step_builder_factory else NodeFromStepBuilder
)

def build_tree_for_mission(self, context: BehaviorTreeBuilderContext) -> BehaviorTree:
mission = context.mission
tree = BehaviorTreeSequential(label=f"mission {mission.id}")
tree.add_node(MissionInProgressNode(context, label="mission start"))
step_builder = self._step_builder_factory(context)

for step, ix in zip(mission.definition.steps, range(len(mission.definition.steps))):
# TODO build the right kind of behavior node
try:
node = step.accept(step_builder)
except Exception as e: # TODO
raise Exception(f"Error building step #{ix} [{step}]: {str(e)}")
# Before every step, keep robot locked
tree.add_node(LockRobotNode(context, label="lock robot"))
if step.timeout_secs is not None and type(node) not in (WaitNode, TimeoutNode):
node = TimeoutNode(step.timeout_secs, node, label=f"timeout for {step.label}")
def _build_step_decorator_for_context(
self, context: BehaviorTreeBuilderContext
) -> Callable[[MissionStep, BehaviorTree], BehaviorTreeSequential]:
def _step_decorator_fn(
step: MissionStep, core_node: BehaviorTree
) -> BehaviorTreeSequential:
"""
Wraps a step node with lock robot, timeout, and task tracking nodes.
Returns a BehaviorTreeSequential containing all necessary nodes for the step.
"""
sequential = BehaviorTreeSequential(label=step.label)

# Always add lock robot node before the step
sequential.add_node(LockRobotNode(context, label="lock robot"))

# Add task started node if complete_task is set
if step.complete_task is not None:
tree.add_node(
sequential.add_node(
TaskStartedNode(
context,
step.complete_task,
label=f"report task {step.complete_task} started",
)
)
if node:
tree.add_node(node)

# Wrap core node in TimeoutNode if timeout_secs is set and node is not already WaitNode or TimeoutNode
if step.timeout_secs is not None and type(core_node) not in (WaitNode, TimeoutNode):
core_node = TimeoutNode(
step.timeout_secs, core_node, label=f"timeout for {step.label}"
)

# Add the core node (possibly wrapped in TimeoutNode)
if core_node:
sequential.add_node(core_node)

# Add task completed node if complete_task is set
if step.complete_task is not None:
tree.add_node(
sequential.add_node(
TaskCompletedNode(
context,
step.complete_task,
label=f"report task {step.complete_task} completed",
)
)

return sequential

return _step_decorator_fn

def build_tree_for_mission(self, context: BehaviorTreeBuilderContext) -> BehaviorTree:
mission = context.mission
tree = BehaviorTreeSequential(label=f"mission {mission.id}")
tree.add_node(MissionInProgressNode(context, label="mission start"))
step_builder = self._step_builder_factory(context)
step_builder.add_step_node_decorator(self._build_step_decorator_for_context(context))

for step, ix in zip(mission.definition.steps, range(len(mission.definition.steps))):
try:
node = step.accept(step_builder)
except Exception as e: # TODO
raise Exception(f"Error building step #{ix} [{step}]: {str(e)}")
if node:
tree.add_node(node)

tree.add_node(MissionCompletedNode(context, label="mission completed"))
tree.add_node(UnlockRobotNode(context, label="unlock robot after mission completed"))
# add error handlers
Expand Down
Loading