
Commit f5443b6

Refactor OpenAICUAModel to streamline action handling and improve code organization
1 parent b3d409f commit f5443b6
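
As context for the diff below: the refactor replaces the long match/case in cua_action_to_bgym_action with a dispatch table that maps each CUA action type to a small handler method returning a (tool_name, args) pair. The following is a minimal, self-contained sketch of that pattern, not part of the commit; SimpleAction, ActionDispatcher, and the two handlers shown are invented for illustration (the real code uses _handle_click_action and friends on OpenAICUAModel).

from dataclasses import dataclass


@dataclass
class SimpleAction:
    # Hypothetical stand-in for an OpenAI CUA action object (illustration only).
    type: str
    x: int = 0
    y: int = 0
    button: str = "left"


class ActionDispatcher:
    """Sketch of the dispatch-table pattern introduced by the refactor."""

    def to_tool_name_and_args(self, action: SimpleAction):
        handlers = {
            "click": self._handle_click,
            "wait": self._handle_wait,
        }
        if action.type not in handlers:
            raise ValueError(f"Unrecognized action type: {action.type}")
        # Each handler returns a (tool_name, arguments) pair.
        return handlers[action.type](action)

    def _handle_click(self, action):
        # Fall back to a left click for unsupported buttons, as the diff does.
        button = action.button if action.button in ("left", "right") else "left"
        return "mouse_click", {"x": action.x, "y": action.y, "button": button}

    def _handle_wait(self, action):
        return "noop", {}


# Example: a middle click at (100, 200) is coerced to a left click.
print(ActionDispatcher().to_tool_name_and_args(SimpleAction("click", 100, 200, "middle")))
# ('mouse_click', {'x': 100, 'y': 200, 'button': 'left'})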

File tree: 1 file changed (+60, -211 lines)


src/agentlab/agents/tool_use_agent/openai_cua.py

Lines changed: 60 additions & 211 deletions
@@ -4,14 +4,10 @@
 
 from agentlab.llm.llm_utils import call_openai_api_with_retries
 from agentlab.llm.response_api import (
-    ContentItem,
-    LLMOutput,
-    Message,
     MessageBuilder,
     OpenAIResponseAPIMessageBuilder,
     OpenAIResponseModel,
     OpenAIResponseModelArgs,
-    ToolCall,
     ToolCalls,
 )
 
@@ -28,16 +24,8 @@
 
 class OpenAICUAModel(OpenAIResponseModel):
 
-    def _call_api(self, messages: list[Any | MessageBuilder], tool_choice="auto", **kwargs) -> dict:
-        input = []
-        for msg in messages:
-            if isinstance(msg, MessageBuilder):
-                temp = msg.prepare_message()
-            elif isinstance(msg, ToolCalls):
-                temp = msg.raw_calls
-            else:
-                raise TypeError('Unsupported message type: {}'.format(type(msg)))
-            input.extend(temp)
+    def _call_api(self, messages: list[ToolCalls | MessageBuilder], tool_choice="auto", **kwargs) -> dict:
+        input = self.convert_messages_to_api_format(messages)
 
         api_params: Dict[str, Any] = {
             "model": self.model_name,
@@ -53,14 +41,15 @@ def _call_api(self, messages: list[Any | MessageBuilder], tool_choice="auto", **
         cua_tool_present = any(
             tool.get("type") == "computer_use_preview" for tool in api_params["tools"]
         )
+        # CUA requires this tool
         if not cua_tool_present:
             api_params["tools"].extend(
                 [
                     {
                         "type": "computer_use_preview",
-                        "display_width": 1024,
+                        "display_width": 1024,
                         "display_height": 768,
-                        "environment": "browser", # other possible values: "mac", "windows", "ubuntu"
+                        "environment": "browser", # TODO: Parametrize this
                     }
                 ]
             )
@@ -72,212 +61,72 @@ def _call_api(self, messages: list[Any | MessageBuilder], tool_choice="auto", **
 
         return response
 
-    def _parse_response(self, response: dict) -> dict:
-        result = LLMOutput(
-            raw_response=response,
-            think="",
-            action=None,
-            tool_calls=ToolCalls(),
-        )
-        interesting_keys = ["output_text"]
-        actions = [] # Collect all actions for multi-action support
-
-        for output in response.output:
-            if output.type in "computer_call":
-                # Mapping CUA action space to bgym coord action space.
-                bgym_fn, bgym_fn_args, action_str = (
-                    self.cua_action_to_bgym_action(output.action)
-                )
-                tool_call = ToolCall(
-                    name=bgym_fn,
-                    arguments=bgym_fn_args,
-                    raw_call=output,
-                )
-                result.tool_calls.add_tool_call(tool_call)
-                actions.append(action_str)
-
-            elif output.type == "function_call":
-                arguments = json.loads(output.arguments)
-                func_args_str = ", ".join(
-                    [
-                        f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
-                        for k, v in arguments.items()
-                    ]
-                )
-                action_str = f"{output.name}({func_args_str})"
-                tool_call = ToolCall(
-                    name=output.name,
-                    arguments=arguments,
-                    raw_call=output,
-                )
-                result.tool_calls.add_tool_call(tool_call)
-                if tool_call.is_bgym_action():
-                    actions.append(action_str)
-
-            elif output.type == "reasoning":
-                if len(output.summary) > 0:
-                    result.think += output.summary[0].text + "\n"
-
-            elif output.type == "message" and output.content:
-                result.think += output.content[0].text + "\n"
-
-        result.action = actions
-        result.tool_calls.raw_calls = response.output
-
-        for key in interesting_keys:
-            if key_content := getattr(output, "output_text", None) is not None:
-                result.think += f"<{key}>{key_content}</{key}>"
-        return result
-
-    @staticmethod
-    def cua_action_to_bgym_action(action) -> str:
+    def cua_action_to_env_tool_name_and_args(self, action) -> str:
         """
         Given a computer action (e.g., click, double_click, scroll, etc.),
         convert it to a text description.
         """
+        # TODO: Provide an alternate implementation for OS-World.
 
         action_type = action.type
 
         try:
-            match action_type:
-
-                case "click":
-                    x, y = action.x, action.y
-                    button = action.button
-                    print(f"Action: click at ({x}, {y}) with button '{button}'")
-                    # Not handling things like middle click, etc.
-                    if button != "left" and button != "right":
-                        button = "left"
-                    action_str = f"mouse_click({x}, {y}, button='{button}')"
-                    (
-                        bgym_fn,
-                        bgym_fn_args,
-                    ) = "mouse_click", {"x": x, "y": y, "button": button}
-
-                case "scroll":
-                    x, y = action.x, action.y
-                    scroll_x, scroll_y = action.scroll_x, action.scroll_y
-                    action_str = f"scroll_at({x}, {y}, {scroll_x}, {scroll_y})"
-                    bgym_fn, bgym_fn_args = "scroll_at", {
-                        "x": x,
-                        "y": y,
-                        "scroll_x": scroll_x,
-                        "scroll_y": scroll_y,
-                    }
-
-                case "keypress":
-                    keys = action.keys
-                    for k in keys:
-                        print(f"Action: keypress '{k}'")
-                        # A simple mapping for common keys; expand as needed.
-                        if k.lower() == "enter":
-                            action_str = "keyboard_press('Enter')"
-                        elif k.lower() == "space":
-                            action_str = "keyboard_press(' ')"
-                        else:
-                            action_str = f"keyboard_press('{k}')"
-
-                        bgym_fn, bgym_fn_args = "keyboard_press", {"key": k}
-
-                case "type":
-                    text = action.text
-                    print(f"Action: type text: {text}")
-                    action_str = f"keyboard_type('{text}')"
-                    bgym_fn, bgym_fn_args = "keyboard_type", {"text": text}
-
-                case "wait":
-                    print("Action: wait")
-                    action_str = "noop()"
-                    bgym_fn, bgym_fn_args = "noop", {}
-
-                case "screenshot":
-                    # Not a valid bgym action
-                    action_str = "noop()"
-                    bgym_fn, bgym_fn_args = "noop", {}
-
-                case "drag":
-                    x1, y1 = action.path[0].x, action.path[0].y
-                    x2, y2 = action.path[1].x, action.path[1].y
-                    print(f"Action: drag from ({x1}, {y1}) to ({x2}, {y2})")
-                    action_str = f"mouse_drag_and_drop({x1}, {y1}, {x2}, {y2})"
-                    bgym_fn, bgym_fn_args = "mouse_drag_and_drop", {
-                        "x1": x1,
-                        "y1": y1,
-                        "x2": x2,
-                        "y2": y2,
-                    }
-
-                case _:
-                    raise ValueError(f"Unrecognized action type: {action_type}")
-
-            # Return the function name and arguments for bgym
-
-            return bgym_fn, bgym_fn_args, action_str
+            action_mapping = {
+                "click": lambda: self._handle_click_action(action),
+                "scroll": lambda: self._handle_scroll_action(action),
+                "keypress": lambda: self._handle_keypress_action(action),
+                "type": lambda: self._handle_type_action(action),
+                "wait": lambda: self._handle_wait_action(action),
+                "screenshot": lambda: self._handle_screenshot_action(action),
+                "drag": lambda: self._handle_drag_action(action),
+            }
+
+            if action_type in action_mapping:
+                return action_mapping[action_type]()
+            else:
+                raise ValueError(f"Unrecognized OpenAI CUA action type: {action_type}")
 
         except Exception as e:
             print(f"Error handling action {action}: {e}")
 
-
-class OpenaAICUAMessageBuilder(OpenAIResponseAPIMessageBuilder):
-
-    def prepare_message(self) -> List[Message]:
-        content = []
-        for item in self.content:
-            content.append(self.convert_content_to_expected_format(item))
-        output = [{"role": self.role, "content": content}]
-
-        if self.role != "tool":
-            return output
-        else:
-            return self.handle_tool_call()
-
-    def convert_content_to_expected_format(self, content: ContentItem) -> ContentItem:
-        """Convert the content item to the expected format for OpenAI Responses."""
-        if "text" in content:
-            content_type = "input_text" if self.role != "assistant" else "output_text"
-            return {"type": content_type, "text": content["text"]}
-        elif "image" in content:
-            return {"type": "input_image", "image_url": content["image"]}
-        else:
-            raise ValueError(f"Unsupported content type: {content}")
-
-    def handle_tool_call(self):
-        """Handle the tool call response from the last raw response."""
-        if self.responsed_tool_calls is None:
-            raise ValueError("No tool calls found in responsed_tool_calls")
-
-        output = []
-        for fn_call in self.responsed_tool_calls:
-            call_type = fn_call.raw_call.type
-            call_id = fn_call.raw_call.call_id
-            call_response = fn_call.tool_response # List[ContentItem]
-
-            match call_type:
-                case "function_call":
-                    # image output is not supported in function calls response.
-                    fn_call_response = {
-                        "type": "function_call_output",
-                        "call_id": call_id,
-                        "output": [
-                            self.convert_content_to_expected_format(item) for item in call_response
-                        ],
-                    }
-                    output.append(fn_call_response)
-
-                case "computer_call":
-                    # For computer calls, use only images are expected.
-                    computer_call_output = {
-                        "type": "computer_call_output",
-                        "call_id": call_id,
-                        "output": self.convert_content_to_expected_format(call_response[0]), # list needs to be flattened
-                    }
-                    output.append(computer_call_output) # this needs to be a screenshot
-
-        return output
-
-    def mark_all_previous_msg_for_caching(self):
-        pass
-
+    def _handle_click_action(self, action):
+        x, y = action.x, action.y
+        button = action.button
+        if button != "left" and button != "right":
+            button = "left"
+        return "mouse_click", {"x": x, "y": y, "button": button}
+
+    def _handle_scroll_action(self, action):
+        x, y = action.x, action.y
+        scroll_x, scroll_y = action.scroll_x, action.scroll_y
+        return "scroll_at", {"x": x, "y": y, "scroll_x": scroll_x, "scroll_y": scroll_y}
+
+    def _handle_keypress_action(self, action):
+        keys = action.keys
+        # TODO: Check if this is suitable for the BGYM env.
+        for k in keys:
+            print(f"Action: keypress '{k}'")
+            if k.lower() == "enter":
+                key = "Enter"
+            elif k.lower() == "space":
+                key = " "
+            else:
+                key = k
+        return "keyboard_press", {"key": key}
+
+    def _handle_type_action(self, action):
+        text = action.text
+        return "keyboard_type", {"text": text}
+
+    def _handle_wait_action(self, action):
+        return "noop", {}
+
+    def _handle_screenshot_action(self, action):
+        return "noop", {}
+
+    def _handle_drag_action(self, action):
+        x1, y1 = action.path[0].x, action.path[0].y
+        x2, y2 = action.path[1].x, action.path[1].y
+        print(f"Action: drag from ({x1}, {y1}) to ({x2}, {y2})")
+        return "mouse_drag_and_drop", {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
 
 @dataclass
 class OpenAICUAModelArgs(OpenAIResponseModelArgs):
@@ -297,7 +146,7 @@ def make_model(self, extra_kwargs=None, **kwargs):
         )
 
     def get_message_builder(self) -> MessageBuilder:
-        return OpenaAICUAMessageBuilder
+        return OpenAIResponseAPIMessageBuilder
 
 
 # Default configuration for Computer Use Agent
