12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import List , Literal
15
+ from typing import Any , Dict , List , Literal
16
16
17
17
from rai .tools .ros2 import MoveToPointToolInput
18
18
from rai .types import Point
21
21
Task ,
22
22
TaskArgs ,
23
23
)
24
- from rai_bench .tool_calling_agent .subtasks import (
25
- CheckArgsToolCallSubTask ,
26
- )
27
24
from rai_bench .tool_calling_agent .tasks .manipulation import (
25
+ AlignTwoObjectsTask ,
26
+ GetObjectPositionsTask ,
27
+ GrabExistingObjectTask ,
28
+ MoveExistingObjectFrontTask ,
29
+ MoveExistingObjectLeftTask ,
28
30
MoveToPointTask ,
29
31
)
30
- from rai_bench .tool_calling_agent .validators import (
31
- OrderedCallsValidator ,
32
- )
33
32
34
- ########## SUBTASKS #################################################################
35
- move_to_point_subtask_grab = CheckArgsToolCallSubTask (
36
- expected_tool_name = "move_to_point" ,
37
- expected_args = {"x" : 1.0 , "y" : 2.0 , "z" : 3.0 , "task" : "grab" },
38
- )
39
- move_to_point_subtask_drop = CheckArgsToolCallSubTask (
40
- expected_tool_name = "move_to_point" ,
41
- expected_args = {"x" : 1.2 , "y" : 2.3 , "z" : 3.4 , "task" : "drop" },
42
- )
33
+ BANANA_POSITION = Point (x = 0.1 , y = 0.2 , z = 0.3 )
34
+ BANANA_POSITION_2 = Point (x = 0.4 , y = 0.5 , z = 0.6 )
35
+ CUBE_POSITION = Point (x = 0.7 , y = 0.8 , z = 0.9 )
43
36
44
- ######### VALIDATORS #########################################################################################
45
- move_to_point_ord_val_grab = OrderedCallsValidator (
46
- subtasks = [move_to_point_subtask_grab ]
47
- )
48
- move_to_point_ord_val_drop = OrderedCallsValidator (
49
- subtasks = [move_to_point_subtask_drop ]
50
- )
37
+ BANANA_OBJECT = "banana"
38
+ CUBE_OBJECT = "cube"
39
+ APPLE_OBJECT = "apple"
40
+
41
+ MOVE_TO_GRAB_COORDS : Dict [str , Any ] = {"x" : 1.0 , "y" : 2.0 , "z" : 3.0 , "task" : "grab" }
42
+ MOVE_TO_DROP_COORDS : Dict [str , Any ] = {"x" : 1.2 , "y" : 2.3 , "z" : 3.4 , "task" : "drop" }
51
43
52
44
53
45
def get_manipulation_tasks (
@@ -69,9 +61,15 @@ def get_manipulation_tasks(
69
61
tasks : List [Task ] = []
70
62
71
63
objects = {
72
- "banana" : [Point (x = 0.1 , y = 0.2 , z = 0.3 ), Point (x = 0.4 , y = 0.5 , z = 0.6 )],
73
- "cube" : [Point (x = 0.7 , y = 0.8 , z = 0.9 )],
64
+ BANANA_OBJECT : [BANANA_POSITION ],
65
+ CUBE_OBJECT : [CUBE_POSITION ],
66
+ }
67
+
68
+ objects_with_multiple_bananas = {
69
+ BANANA_OBJECT : [BANANA_POSITION , BANANA_POSITION_2 ],
70
+ CUBE_OBJECT : [CUBE_POSITION ],
74
71
}
72
+
75
73
for extra_calls in extra_tool_calls :
76
74
for detail in prompt_detail :
77
75
for shots in n_shots :
@@ -80,24 +78,58 @@ def get_manipulation_tasks(
80
78
prompt_detail = detail ,
81
79
examples_in_system_prompt = shots ,
82
80
)
81
+
83
82
tasks .extend (
84
83
[
85
84
MoveToPointTask (
86
85
objects = objects ,
87
86
move_to_tool_input = MoveToPointToolInput (
88
87
x = 1.0 , y = 2.0 , z = 3.0 , task = "grab"
89
88
),
90
- validators = [move_to_point_ord_val_grab ],
91
89
task_args = task_args ,
92
90
),
93
91
MoveToPointTask (
94
92
objects = objects ,
95
93
move_to_tool_input = MoveToPointToolInput (
96
94
x = 1.2 , y = 2.3 , z = 3.4 , task = "drop"
97
95
),
98
- validators = [move_to_point_ord_val_drop ],
99
96
task_args = task_args ,
100
97
),
98
+ GetObjectPositionsTask (
99
+ objects = objects_with_multiple_bananas ,
100
+ task_args = task_args ,
101
+ ),
102
+ GrabExistingObjectTask (
103
+ objects = objects ,
104
+ object_to_grab = CUBE_OBJECT ,
105
+ task_args = task_args ,
106
+ ),
107
+ GrabExistingObjectTask (
108
+ objects = objects ,
109
+ object_to_grab = BANANA_OBJECT ,
110
+ task_args = task_args ,
111
+ ),
112
+ MoveExistingObjectLeftTask (
113
+ objects = objects ,
114
+ object_to_grab = CUBE_OBJECT ,
115
+ task_args = task_args ,
116
+ ),
117
+ MoveExistingObjectLeftTask (
118
+ objects = objects ,
119
+ object_to_grab = BANANA_OBJECT ,
120
+ task_args = task_args ,
121
+ ),
122
+ MoveExistingObjectFrontTask (
123
+ objects = objects ,
124
+ object_to_grab = CUBE_OBJECT ,
125
+ task_args = task_args ,
126
+ ),
127
+ MoveExistingObjectFrontTask (
128
+ objects = objects ,
129
+ object_to_grab = BANANA_OBJECT ,
130
+ task_args = task_args ,
131
+ ),
132
+ AlignTwoObjectsTask (objects = objects , task_args = task_args ),
101
133
]
102
134
)
103
135
0 commit comments