Skip to content

Commit ab73ba7

Browse files
authored
feat: tool calling bench - manipulation tasks extension (#656)
1 parent 269f4f6 commit ab73ba7

File tree

5 files changed

+926
-128
lines changed

5 files changed

+926
-128
lines changed

src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ class MockGetObjectPositionsTool(GetObjectPositionsTool):
184184
mock_objects: dict[str, List[Point]]
185185

186186
def _run(self, object_name: str) -> str:
187-
"""Method that returns a mock message with the object positions if the object_name is present in the mock_objects dictionary.
187+
"""Method that returns a mock message with the object positions
188+
if the object_name is present in the mock_objects dictionary.
188189
189190
Parameters
190191
----------

src/rai_bench/rai_bench/tool_calling_agent/predefined/manipulation_tasks.py

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Literal
15+
from typing import Any, Dict, List, Literal
1616

1717
from rai.tools.ros2 import MoveToPointToolInput
1818
from rai.types import Point
@@ -21,33 +21,25 @@
2121
Task,
2222
TaskArgs,
2323
)
24-
from rai_bench.tool_calling_agent.subtasks import (
25-
CheckArgsToolCallSubTask,
26-
)
2724
from rai_bench.tool_calling_agent.tasks.manipulation import (
25+
AlignTwoObjectsTask,
26+
GetObjectPositionsTask,
27+
GrabExistingObjectTask,
28+
MoveExistingObjectFrontTask,
29+
MoveExistingObjectLeftTask,
2830
MoveToPointTask,
2931
)
30-
from rai_bench.tool_calling_agent.validators import (
31-
OrderedCallsValidator,
32-
)
3332

34-
########## SUBTASKS #################################################################
35-
move_to_point_subtask_grab = CheckArgsToolCallSubTask(
36-
expected_tool_name="move_to_point",
37-
expected_args={"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"},
38-
)
39-
move_to_point_subtask_drop = CheckArgsToolCallSubTask(
40-
expected_tool_name="move_to_point",
41-
expected_args={"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"},
42-
)
33+
BANANA_POSITION = Point(x=0.1, y=0.2, z=0.3)
34+
BANANA_POSITION_2 = Point(x=0.4, y=0.5, z=0.6)
35+
CUBE_POSITION = Point(x=0.7, y=0.8, z=0.9)
4336

44-
######### VALIDATORS #########################################################################################
45-
move_to_point_ord_val_grab = OrderedCallsValidator(
46-
subtasks=[move_to_point_subtask_grab]
47-
)
48-
move_to_point_ord_val_drop = OrderedCallsValidator(
49-
subtasks=[move_to_point_subtask_drop]
50-
)
37+
BANANA_OBJECT = "banana"
38+
CUBE_OBJECT = "cube"
39+
APPLE_OBJECT = "apple"
40+
41+
MOVE_TO_GRAB_COORDS: Dict[str, Any] = {"x": 1.0, "y": 2.0, "z": 3.0, "task": "grab"}
42+
MOVE_TO_DROP_COORDS: Dict[str, Any] = {"x": 1.2, "y": 2.3, "z": 3.4, "task": "drop"}
5143

5244

5345
def get_manipulation_tasks(
@@ -69,9 +61,15 @@ def get_manipulation_tasks(
6961
tasks: List[Task] = []
7062

7163
objects = {
72-
"banana": [Point(x=0.1, y=0.2, z=0.3), Point(x=0.4, y=0.5, z=0.6)],
73-
"cube": [Point(x=0.7, y=0.8, z=0.9)],
64+
BANANA_OBJECT: [BANANA_POSITION],
65+
CUBE_OBJECT: [CUBE_POSITION],
66+
}
67+
68+
objects_with_multiple_bananas = {
69+
BANANA_OBJECT: [BANANA_POSITION, BANANA_POSITION_2],
70+
CUBE_OBJECT: [CUBE_POSITION],
7471
}
72+
7573
for extra_calls in extra_tool_calls:
7674
for detail in prompt_detail:
7775
for shots in n_shots:
@@ -80,24 +78,58 @@ def get_manipulation_tasks(
8078
prompt_detail=detail,
8179
examples_in_system_prompt=shots,
8280
)
81+
8382
tasks.extend(
8483
[
8584
MoveToPointTask(
8685
objects=objects,
8786
move_to_tool_input=MoveToPointToolInput(
8887
x=1.0, y=2.0, z=3.0, task="grab"
8988
),
90-
validators=[move_to_point_ord_val_grab],
9189
task_args=task_args,
9290
),
9391
MoveToPointTask(
9492
objects=objects,
9593
move_to_tool_input=MoveToPointToolInput(
9694
x=1.2, y=2.3, z=3.4, task="drop"
9795
),
98-
validators=[move_to_point_ord_val_drop],
9996
task_args=task_args,
10097
),
98+
GetObjectPositionsTask(
99+
objects=objects_with_multiple_bananas,
100+
task_args=task_args,
101+
),
102+
GrabExistingObjectTask(
103+
objects=objects,
104+
object_to_grab=CUBE_OBJECT,
105+
task_args=task_args,
106+
),
107+
GrabExistingObjectTask(
108+
objects=objects,
109+
object_to_grab=BANANA_OBJECT,
110+
task_args=task_args,
111+
),
112+
MoveExistingObjectLeftTask(
113+
objects=objects,
114+
object_to_grab=CUBE_OBJECT,
115+
task_args=task_args,
116+
),
117+
MoveExistingObjectLeftTask(
118+
objects=objects,
119+
object_to_grab=BANANA_OBJECT,
120+
task_args=task_args,
121+
),
122+
MoveExistingObjectFrontTask(
123+
objects=objects,
124+
object_to_grab=CUBE_OBJECT,
125+
task_args=task_args,
126+
),
127+
MoveExistingObjectFrontTask(
128+
objects=objects,
129+
object_to_grab=BANANA_OBJECT,
130+
task_args=task_args,
131+
),
132+
AlignTwoObjectsTask(objects=objects, task_args=task_args),
101133
]
102134
)
103135

0 commit comments

Comments
 (0)