
Commit 7540fc1

add openai computer use agent
1 parent 3c2422e commit 7540fc1

File tree

3 files changed: +315 −0 lines changed

src/agentlab/agents/openai_cua/__init__.py

Whitespace-only changes.
Lines changed: 296 additions & 0 deletions
@@ -0,0 +1,296 @@
from dataclasses import dataclass
import logging

from bgym import HighLevelActionSetArgs
from browsergym.experiments import AbstractAgentArgs, Agent, AgentInfo
from agentlab.llm.llm_utils import image_to_jpg_base64_url

import openai
client = openai.OpenAI()

@dataclass
class OpenAIComputerUseAgentArgs(AbstractAgentArgs):
    """
    Arguments for the OpenAI Computer Use Agent.
    """
    agent_name: str = None
    model: str = "computer-use-preview"
    tool_type: str = "computer_use_preview"
    display_width: int = 1024
    display_height: int = 768
    environment: str = "browser"
    reasoning_summary: str = "concise"
    truncation: str = "auto"  # Always set to "auto" for the OpenAI API
    action_set: HighLevelActionSetArgs = None
    enable_safety_checks: bool = False  # Optional, defaults to False; only use in demo mode
    implicit_agreement: bool = True  # If True, the prompt pre-approves all actions so the model does not ask for confirmation

    def __post_init__(self):
        if self.agent_name is None:
            self.agent_name = "OpenAIComputerUseAgent"

    def set_benchmark(self, benchmark, demo_mode):
        pass

    def set_reproducibility_mode(self):
        pass

    def make_agent(self):
        return OpenAIComputerUseAgent(
            model=self.model,
            tool_type=self.tool_type,
            display_width=self.display_width,
            display_height=self.display_height,
            environment=self.environment,
            reasoning_summary=self.reasoning_summary,
            truncation=self.truncation,
            action_set=self.action_set,
            enable_safety_checks=self.enable_safety_checks,
            implicit_agreement=self.implicit_agreement,
        )

class OpenAIComputerUseAgent(Agent):
    def __init__(
        self,
        model: str,
        tool_type: str,
        display_width: int,
        display_height: int,
        environment: str,
        reasoning_summary: str,
        truncation: str,
        action_set: HighLevelActionSetArgs,
        enable_safety_checks: bool = False,
        implicit_agreement: bool = True,
    ):
        self.model = model
        self.reasoning_summary = reasoning_summary
        self.truncation = truncation
        self.enable_safety_checks = enable_safety_checks
        self.implicit_agreement = implicit_agreement

        self.action_set = action_set.make_action_set()

        # Safety checks are only surfaced to the user in demo mode.
        assert not self.enable_safety_checks or (
            self.action_set.demo_mode is not None and self.action_set.demo_mode != "off"
        ), "Safety checks are enabled but no demo mode is set. Please set demo_mode (e.g. 'all_blue')."

        self.computer_calls = []
        self.pending_checks = []
        self.previous_response_id = None
        self.last_call_id = None
        self.initialized = False  # Becomes True once the API has been called on the first get_action
        self.answer_assistant = None  # Stores the user answer to send back to the assistant
        self.agent_info = AgentInfo()

        self.tools = [
            {
                "type": tool_type,
                "display_width": display_width,
                "display_height": display_height,
                "environment": environment,
            }
        ]
        self.inputs = []

    def parse_action_to_bgym(self, action) -> str:
        """
        Parse the action returned by the OpenAI API into a bgym action string.
        """
        action_type = action.type

        match action_type:
            case "click":
                x, y = action.x, action.y
                button = action.button
                if button != "left" and button != "right":
                    button = "left"
                return f"mouse_click({x}, {y}, button='{button}')"

            case "scroll":
                x, y = action.x, action.y
                dx, dy = action.scroll_x, action.scroll_y
                return f"scroll_at({x}, {y}, {dx}, {dy})"

            case "keypress":
                # Map OpenAI key names to the names keyboard_press expects,
                # joining multi-key combinations with "+" (e.g. Control+A).
                key_map = {"enter": "Enter", "space": " ", "ctrl": "Control"}
                combo = "+".join(key_map.get(k.lower(), k) for k in action.keys)
                return f"keyboard_press('{combo}')"

            case "type":
                text = action.text
                return f"keyboard_insert_text('{text}')"

            case "drag":
                from_x, from_y = action.path[0].x, action.path[0].y
                to_x, to_y = action.path[-1].x, action.path[-1].y
                return f"mouse_drag_and_drop({from_x}, {from_y}, {to_x}, {to_y})"

            case "move":
                x, y = action.x, action.y
                return f"mouse_move({x}, {y})"

            case "wait":
                return "noop(2000)"  # wait for 2 seconds

            case "screenshot":
                # Nothing to do: a screenshot is already taken at each turn.
                return "noop()"

            case _:
                logging.error(f"No action found for {action_type}. Please check the action type.")
                return None

    def start_session(self, goal: str, screenshot_base64: str):
        """
        Call the OpenAI API with the provided goal and screenshot to initiate a session.

        Args:
            goal (str): The goal or task description for the agent.
            screenshot_base64 (str): Base64-encoded screenshot of the current state.

        Returns:
            response: The Response object returned by the OpenAI API.
        """
        instruction = goal
        if self.implicit_agreement:
            instruction = f"""
Please perform the following task. I can confirm it is safe, and you can proceed with all actions without asking for confirmation.

Task:
{goal}
"""

        response = self.call_api(
            input=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": instruction,
                        },
                        {
                            "type": "input_image",
                            "image_url": f"{screenshot_base64}",
                        },
                    ],
                }
            ],
            reasoning={
                "summary": self.reasoning_summary,
            },
        )
        return response

    def call_api(self, input: list, previous_response_id=None, **kwargs):
        response = client.responses.create(
            model=self.model,
            previous_response_id=previous_response_id,
            tools=self.tools,
            input=input,
            truncation=self.truncation,  # Always set to "auto"
            **kwargs,
        )
        return response

    def get_action(self, obs):
        goal = obs["goal"]
        screenshot_base64 = image_to_jpg_base64_url(obs["screenshot"])

        if not self.initialized:
            print("Initializing OpenAI Computer Use Agent with goal:", goal)
            response = self.start_session(goal, screenshot_base64)
            for item in response.output:
                if item.type == "reasoning":
                    self.agent_info.think = item.summary[0].text if item.summary else None
                if item.type == "computer_call":
                    self.computer_calls.append(item)
            self.previous_response_id = response.id
            self.initialized = True

        if len(self.computer_calls) > 0:
            logging.debug("Found computer calls from the previous response. Processing them...")
            computer_call = self.computer_calls.pop(0)
            if not self.enable_safety_checks:
                # Bypass safety checks by acknowledging them all automatically.
                self.pending_checks = computer_call.pending_safety_checks
                print(f"Pending safety checks: {self.pending_checks}")
            action = self.parse_action_to_bgym(computer_call.action)
            self.last_call_id = computer_call.call_id
            return action, self.agent_info
        else:
            logging.debug("Last call ID: %s", self.last_call_id)
            logging.debug("Previous response ID: %s", self.previous_response_id)
            self.inputs.append(
                {
                    "call_id": self.last_call_id,
                    "type": "computer_call_output",
                    "acknowledged_safety_checks": self.pending_checks,
                    "output": {
                        "type": "input_image",
                        "image_url": f"{screenshot_base64}",  # current screenshot
                    },
                }
            )

            if self.answer_assistant:
                self.inputs.append(self.answer_assistant)
                self.answer_assistant = None

            response = self.call_api(self.inputs, self.previous_response_id)
            self.previous_response_id = response.id

            self.computer_calls = [item for item in response.output if item.type == "computer_call"]
            if not self.computer_calls:
                logging.debug(f"No computer call found. Output from model: {response.output}")
                for item in response.output:
                    if item.type == "reasoning":
                        self.agent_info.think = item.summary[0].text if item.summary else None
                    if hasattr(item, "role") and item.role == "assistant":
                        # Assume the assistant asked for user confirmation.
                        # Always answer with: "Yes, continue."
                        self.answer_assistant = {
                            "role": "user",
                            "content": [
                                {
                                    "type": "input_text",
                                    "text": "Yes, continue.",
                                }
                            ],
                        }
                        return f"send_msg_to_user('{item.content[0].text}')", self.agent_info
                logging.debug("No action found in the response. Returning None.")
                return None, self.agent_info

            computer_call = self.computer_calls.pop(0)
            self.last_call_id = computer_call.call_id
            action = self.parse_action_to_bgym(computer_call.action)
            logging.debug("Action: %s", action)
            if not self.enable_safety_checks:
                # Bypass safety checks by acknowledging them all automatically.
                self.pending_checks = computer_call.pending_safety_checks
            else:
                # TODO: Handle safety checks when enabled in demo mode, e.g.:
                # self.pending_checks = computer_call.pending_safety_checks
                # for check in self.pending_checks:
                #     do_something_to_acknowledge_check(check)
                pass

            for item in response.output:
                if item.type == "reasoning":
                    self.agent_info.think = item.summary[0].text if item.summary else None
                    break

            return action, self.agent_info
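
For reference, a minimal sketch (not part of the commit) of how the action translation above behaves: it builds the agent from its args and feeds parse_action_to_bgym a few SimpleNamespace stand-ins for the SDK's computer-call action objects. The stand-in values are hypothetical, chosen only to mirror the attributes the parser reads, and running this requires OPENAI_API_KEY since the module creates an openai.OpenAI() client at import time.

from types import SimpleNamespace

from bgym import HighLevelActionSetArgs

args = OpenAIComputerUseAgentArgs(
    action_set=HighLevelActionSetArgs(subsets=("chat", "coord"), demo_mode=None),
)
agent = args.make_agent()

# A click maps to a bgym mouse_click string.
click = SimpleNamespace(type="click", x=100, y=200, button="left")
print(agent.parse_action_to_bgym(click))  # mouse_click(100, 200, button='left')

# A drag uses the first and last points of the path.
drag = SimpleNamespace(
    type="drag",
    path=[SimpleNamespace(x=10, y=10), SimpleNamespace(x=300, y=120)],
)
print(agent.parse_action_to_bgym(drag))  # mouse_drag_and_drop(10, 10, 300, 120)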
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
from bgym import HighLevelActionSetArgs

from .agent import OpenAIComputerUseAgentArgs

OPENAI_CUA_AGENT_ARGS = OpenAIComputerUseAgentArgs(
    model="computer-use-preview",
    tool_type="computer_use_preview",
    display_width=1024,
    display_height=768,
    environment="browser",
    reasoning_summary="concise",
    truncation="auto",
    action_set=HighLevelActionSetArgs(
        subsets=("chat", "coord"),
        demo_mode=None,
    ),
    enable_safety_checks=False,
    implicit_agreement=True,
)
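
A hedged usage sketch for the preconfigured args: it wires OPENAI_CUA_AGENT_ARGS into the browsergym.experiments loop that the agent file already imports from. The EnvArgs/ExpArgs field names, the task name, and the results directory here are assumptions for illustration, not part of the commit.

from browsergym.experiments import EnvArgs, ExpArgs

exp_args = ExpArgs(
    agent_args=OPENAI_CUA_AGENT_ARGS,
    env_args=EnvArgs(task_name="miniwob.click-button", max_steps=20, headless=True),
)
exp_args.prepare("./results")  # create the experiment directory
exp_args.run()                 # run the agent/environment loop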

0 commit comments