Skip to content

Commit e193759

Browse files
authored
Add Conditional Execution Support with IfNode (#19)
* add if step * black * remove todo * fix step alias * fix tests * added not yet implemented visit_if in NodeFromStepBuilder * pre working * implement if node * fix merge * black * add retry logic * optimize tests * refactor to keep compatibility * working ! * working ! * remove print * black * black * set state on no else
1 parent 4062843 commit e193759

File tree

6 files changed

+902
-44
lines changed

6 files changed

+902
-44
lines changed

example.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
MissionRuntimeOptions,
3131
MissionStepSetData,
3232
MissionStepPoseWaypoint,
33+
MissionStepIf,
3334
Pose,
3435
)
3536
from inorbit_edge_executor.worker_pool import WorkerPool
@@ -121,8 +122,11 @@ async def main():
121122
"state": "starting",
122123
"label": "Example inorbit_edge_executor mission",
123124
"tasks": [
125+
{"taskId": "step 0", "label": "Step 0"},
124126
{"taskId": "step 1", "label": "Step 1"},
125127
{"taskId": "step 2", "label": "Step 2"},
128+
{"taskId": "step 3", "label": "Then waypoint A"},
129+
{"taskId": "step 4", "label": "Else waypoint B"},
126130
],
127131
},
128132
)
@@ -135,13 +139,66 @@ async def main():
135139
definition=MissionDefinition(
136140
label="A mission definition",
137141
steps=[
138-
MissionStepSetData(
139-
label="set some data", completeTask="step 1", data={"key": "value"}
142+
MissionStepSetData.model_validate(
143+
{
144+
"label": "set some data",
145+
"completeTask": "step 0",
146+
"data": {"key": "value"},
147+
}
140148
),
141-
MissionStepPoseWaypoint(
142-
label="go to waypoint",
143-
completeTask="step 2",
144-
waypoint=Pose(x=0, y=0, theta=0, frameId="map", waypointId="wp1"),
149+
MissionStepSetData.model_validate(
150+
{
151+
"label": "set more data",
152+
"completeTask": "step 1",
153+
"data": {"key2": "value2"},
154+
}
155+
),
156+
MissionStepPoseWaypoint.model_validate(
157+
{
158+
"label": "go to waypoint",
159+
"completeTask": "step 2",
160+
"waypoint": {
161+
"x": 0,
162+
"y": 0,
163+
"theta": 0,
164+
"frameId": "map",
165+
"waypointId": "wp1",
166+
},
167+
}
168+
),
169+
MissionStepIf.model_validate(
170+
{
171+
"label": "if",
172+
"if": {
173+
"expression": "0 > 1",
174+
"then": [
175+
{
176+
"label": "go to waypoint A",
177+
"completeTask": "step 3",
178+
"waypoint": {
179+
"x": 0,
180+
"y": 0,
181+
"theta": 0,
182+
"frameId": "map",
183+
"waypointId": "wpA",
184+
},
185+
},
186+
],
187+
"else": [
188+
{
189+
"label": "go to waypoint B",
190+
"completeTask": "step 4",
191+
"waypoint": {
192+
"x": 0,
193+
"y": 0,
194+
"theta": 0,
195+
"frameId": "map",
196+
"waypointId": "wpB",
197+
},
198+
},
199+
],
200+
},
201+
}
145202
),
146203
],
147204
),

inorbit_edge_executor/behavior_tree.py

Lines changed: 217 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
from typing import Dict
2727
from typing import List
2828
from typing import Union
29+
from typing import Callable
2930

3031
from async_timeout import timeout
3132

3233
from .datatypes import MissionRuntimeOptions
3334
from .datatypes import MissionRuntimeSharedMemory
35+
from .datatypes import MissionStep
3436
from .datatypes import MissionStepPoseWaypoint
3537
from .datatypes import MissionStepRunAction
3638
from .datatypes import MissionStepSetData
@@ -686,6 +688,119 @@ def from_object(cls, context, expression, target=None, **kwargs):
686688
return WaitExpressionNode(context, expression, target, **kwargs)
687689

688690

691+
class IfNode(BehaviorTree):
692+
"""
693+
Node that evaluates an expression once and conditionally executes either a "then" or "else"
694+
branch based on the result. The expression is evaluated through REST APIs, normally in the same
695+
robot that executes the mission.
696+
"""
697+
698+
def __init__(
699+
self,
700+
context: BehaviorTreeBuilderContext,
701+
expression: str,
702+
then_branch: BehaviorTree,
703+
else_branch: BehaviorTree = None,
704+
target: Target = None,
705+
retry_wait_secs: float = 3,
706+
**kwargs,
707+
):
708+
super().__init__(**kwargs)
709+
self.expression = expression
710+
self.retry_wait_secs = retry_wait_secs
711+
self.then_branch = then_branch
712+
self.else_branch = else_branch
713+
self.target = target
714+
if self.target is None:
715+
self.robot = context.robot_api
716+
else:
717+
self.robot = context.robot_api_factory.build(self.target.robot_id)
718+
719+
async def _execute(self):
720+
logger.debug(f"evaluating expression {self.expression} on {self.robot.robot_id}")
721+
try:
722+
result = None
723+
max_attempts = 5
724+
for attempt in range(1, max_attempts + 1):
725+
try:
726+
logger.debug(
727+
f"Attempt {attempt}/{max_attempts} to evaluate expression: {self.expression}"
728+
)
729+
result = await self.robot.evaluate_expression(self.expression)
730+
break
731+
except Exception as e:
732+
logger.warning(
733+
f"Attempt {attempt} failed for expression {self.expression}: {e}"
734+
)
735+
if attempt < max_attempts:
736+
await asyncio.sleep(self.retry_wait_secs)
737+
else:
738+
logger.error(
739+
f"All {max_attempts} attempts failed for expression {self.expression}"
740+
)
741+
raise
742+
except Exception as e:
743+
logger.error(f"Error evaluating expression {self.expression}: {e}")
744+
raise e
745+
746+
if result:
747+
logger.debug(f"expression {self.expression} == true, executing then branch")
748+
await self.then_branch.execute()
749+
self.state = self.then_branch.state
750+
self.last_error = self.then_branch.last_error
751+
else:
752+
if self.else_branch is not None:
753+
logger.debug(f"expression {self.expression} == false, executing else branch")
754+
await self.else_branch.execute()
755+
self.state = self.else_branch.state
756+
self.last_error = self.else_branch.last_error
757+
else:
758+
logger.debug(f"expression {self.expression} == false, no else branch, succeeding")
759+
# No else branch, succeed (no-op)
760+
self.state = NODE_STATE_SUCCESS
761+
self.last_error = ""
762+
763+
def reset_execution(self):
764+
super().reset_execution()
765+
if self.then_branch:
766+
self.then_branch.reset_execution()
767+
if self.else_branch:
768+
self.else_branch.reset_execution()
769+
770+
def reset_handlers_execution(self):
771+
super().reset_handlers_execution()
772+
if self.then_branch:
773+
self.then_branch.reset_handlers_execution()
774+
if self.else_branch:
775+
self.else_branch.reset_handlers_execution()
776+
777+
def collect_nodes(self, nodes_list: List):
778+
super().collect_nodes(nodes_list)
779+
if self.then_branch:
780+
self.then_branch.collect_nodes(nodes_list)
781+
if self.else_branch:
782+
self.else_branch.collect_nodes(nodes_list)
783+
784+
def dump_object(self):
785+
object = super().dump_object()
786+
object["expression"] = self.expression
787+
object["then_branch"] = self.then_branch.dump_object()
788+
if self.else_branch is not None:
789+
object["else_branch"] = self.else_branch.dump_object()
790+
if self.target is not None:
791+
object["target"] = self.target.dump_object()
792+
object["retry_wait_secs"] = self.retry_wait_secs
793+
return object
794+
795+
@classmethod
796+
def from_object(cls, context, expression, then_branch, else_branch=None, target=None, **kwargs):
797+
then_branch_tree = build_tree_from_object(context, then_branch)
798+
else_branch_tree = build_tree_from_object(context, else_branch) if else_branch else None
799+
if target is not None:
800+
target = Target.from_object(**target)
801+
return IfNode(context, expression, then_branch_tree, else_branch_tree, target, **kwargs)
802+
803+
689804
class DummyNode(BehaviorTree):
690805
async def _execute(self):
691806
pass
@@ -983,6 +1098,11 @@ def from_object(cls, context, node_state, **kwargs):
9831098

9841099
class NodeFromStepBuilder:
9851100
def __init__(self, context: BehaviorTreeBuilderContext):
1101+
"""
1102+
Implements the visitor pattern for building behavior tree nodes from mission steps.
1103+
Args:
1104+
context: The behavior tree builder context.
1105+
"""
9861106
self.context = context
9871107
self.waypoint_distance_tolerance = WAYPOINT_DISTANCE_TOLERANCE_DEFAULT
9881108
self.waypoint_angular_tolerance = WAYPOINT_ANGULAR_TOLERANCE_DEFAULT
@@ -1001,6 +1121,27 @@ def __init__(self, context: BehaviorTreeBuilderContext):
10011121
if WAYPOINT_ANGULAR_TOLERANCE in args:
10021122
self.waypoint_angular_tolerance = float(args[WAYPOINT_ANGULAR_TOLERANCE])
10031123

1124+
def add_step_node_decorator(
1125+
self, step_decorator_fn: Callable[[MissionStep, BehaviorTree], BehaviorTree]
1126+
):
1127+
# Patch all visit_* methods so that they call the step decorator around the real core node
1128+
for attr_name in dir(self):
1129+
if attr_name.startswith("visit_") and callable(getattr(self, attr_name)):
1130+
orig_method = getattr(self, attr_name)
1131+
# Don't double-wrap if it's already wrapped (avoid recursion)
1132+
if hasattr(orig_method, "__wrapped_with_step_wrapper__"):
1133+
continue
1134+
1135+
def make_wrapped(orig_method):
1136+
def visit_method(step):
1137+
core_node = orig_method(step)
1138+
return step_decorator_fn(step, core_node)
1139+
1140+
visit_method.__wrapped_with_step_wrapper__ = True
1141+
return visit_method
1142+
1143+
setattr(self, attr_name, make_wrapped(orig_method))
1144+
10041145
def visit_wait(self, step: MissionStepWait):
10051146
return WaitNode(self.context, step.timeout_secs, label=step.label)
10061147

@@ -1117,7 +1258,32 @@ def visit_wait_until(self, step: MissionStepWaitUntil):
11171258
)
11181259

11191260
def visit_if(self, step: MissionStepIf):
1120-
raise NotImplementedError("visit_if not implemented")
1261+
# Build the behavior tree nodes for the then branch
1262+
then_label = f"{step.label} - then" if step.label else "then"
1263+
then_branch = BehaviorTreeSequential(label=then_label)
1264+
for then_step in step.then:
1265+
node = then_step.accept(self)
1266+
if node:
1267+
then_branch.add_node(node)
1268+
# Build the behavior tree nodes for the else branch (if it exists)
1269+
else_branch = None
1270+
if step.else_ is not None:
1271+
else_label = f"{step.label} - else" if step.label else "else"
1272+
else_branch = BehaviorTreeSequential(label=else_label)
1273+
for else_step in step.else_:
1274+
node = else_step.accept(self)
1275+
if node:
1276+
else_branch.add_node(node)
1277+
# Create the if node
1278+
if_node = IfNode(
1279+
context=self.context,
1280+
expression=step.expression,
1281+
then_branch=then_branch,
1282+
else_branch=else_branch,
1283+
target=step.target,
1284+
label=step.label,
1285+
)
1286+
return if_node
11211287

11221288

11231289
# List of accepted node types (classes). With register_accepted_node_types(),
@@ -1128,6 +1294,7 @@ def visit_if(self, step: MissionStepIf):
11281294
WaitNode,
11291295
RunActionNode,
11301296
WaitExpressionNode,
1297+
IfNode,
11311298
DummyNode,
11321299
TimeoutNode,
11331300
MissionStartNode,
@@ -1181,41 +1348,70 @@ def __init__(self, step_builder_factory: NodeFromStepBuilder = None, *args, **kw
11811348
step_builder_factory if step_builder_factory else NodeFromStepBuilder
11821349
)
11831350

1184-
def build_tree_for_mission(self, context: BehaviorTreeBuilderContext) -> BehaviorTree:
1185-
mission = context.mission
1186-
tree = BehaviorTreeSequential(label=f"mission {mission.id}")
1187-
tree.add_node(MissionInProgressNode(context, label="mission start"))
1188-
step_builder = self._step_builder_factory(context)
1189-
1190-
for step, ix in zip(mission.definition.steps, range(len(mission.definition.steps))):
1191-
# TODO build the right kind of behavior node
1192-
try:
1193-
node = step.accept(step_builder)
1194-
except Exception as e: # TODO
1195-
raise Exception(f"Error building step #{ix} [{step}]: {str(e)}")
1196-
# Before every step, keep robot locked
1197-
tree.add_node(LockRobotNode(context, label="lock robot"))
1198-
if step.timeout_secs is not None and type(node) not in (WaitNode, TimeoutNode):
1199-
node = TimeoutNode(step.timeout_secs, node, label=f"timeout for {step.label}")
1351+
def _build_step_decorator_for_context(
1352+
self, context: BehaviorTreeBuilderContext
1353+
) -> Callable[[MissionStep, BehaviorTree], BehaviorTreeSequential]:
1354+
def _step_decorator_fn(
1355+
step: MissionStep, core_node: BehaviorTree
1356+
) -> BehaviorTreeSequential:
1357+
"""
1358+
Wraps a step node with lock robot, timeout, and task tracking nodes.
1359+
Returns a BehaviorTreeSequential containing all necessary nodes for the step.
1360+
"""
1361+
sequential = BehaviorTreeSequential(label=step.label)
1362+
1363+
# Always add lock robot node before the step
1364+
sequential.add_node(LockRobotNode(context, label="lock robot"))
1365+
1366+
# Add task started node if complete_task is set
12001367
if step.complete_task is not None:
1201-
tree.add_node(
1368+
sequential.add_node(
12021369
TaskStartedNode(
12031370
context,
12041371
step.complete_task,
12051372
label=f"report task {step.complete_task} started",
12061373
)
12071374
)
1208-
if node:
1209-
tree.add_node(node)
1375+
1376+
# Wrap core node in TimeoutNode if timeout_secs is set and node is not already WaitNode or TimeoutNode
1377+
if step.timeout_secs is not None and type(core_node) not in (WaitNode, TimeoutNode):
1378+
core_node = TimeoutNode(
1379+
step.timeout_secs, core_node, label=f"timeout for {step.label}"
1380+
)
1381+
1382+
# Add the core node (possibly wrapped in TimeoutNode)
1383+
if core_node:
1384+
sequential.add_node(core_node)
1385+
1386+
# Add task completed node if complete_task is set
12101387
if step.complete_task is not None:
1211-
tree.add_node(
1388+
sequential.add_node(
12121389
TaskCompletedNode(
12131390
context,
12141391
step.complete_task,
12151392
label=f"report task {step.complete_task} completed",
12161393
)
12171394
)
12181395

1396+
return sequential
1397+
1398+
return _step_decorator_fn
1399+
1400+
def build_tree_for_mission(self, context: BehaviorTreeBuilderContext) -> BehaviorTree:
1401+
mission = context.mission
1402+
tree = BehaviorTreeSequential(label=f"mission {mission.id}")
1403+
tree.add_node(MissionInProgressNode(context, label="mission start"))
1404+
step_builder = self._step_builder_factory(context)
1405+
step_builder.add_step_node_decorator(self._build_step_decorator_for_context(context))
1406+
1407+
for step, ix in zip(mission.definition.steps, range(len(mission.definition.steps))):
1408+
try:
1409+
node = step.accept(step_builder)
1410+
except Exception as e: # TODO
1411+
raise Exception(f"Error building step #{ix} [{step}]: {str(e)}")
1412+
if node:
1413+
tree.add_node(node)
1414+
12191415
tree.add_node(MissionCompletedNode(context, label="mission completed"))
12201416
tree.add_node(UnlockRobotNode(context, label="unlock robot after mission completed"))
12211417
# add error handlers

0 commit comments

Comments
 (0)