1+ from dataclasses import dataclass
2+ import logging
3+
4+ from bgym import HighLevelActionSetArgs
5+ from browsergym .experiments import AbstractAgentArgs , Agent , AgentInfo
6+ from agentlab .llm .llm_utils import image_to_jpg_base64_url
7+
import openai

# Module-level OpenAI client shared by every agent instance in this module.
# NOTE(review): constructed at import time, so OPENAI_API_KEY must be set in
# the environment when this module is imported — confirm against deployment.
client = openai.OpenAI()
10+
11+
@dataclass
class OpenAIComputerUseAgentArgs(AbstractAgentArgs):
    """Arguments for the OpenAI Computer Use Agent.

    Fields map one-to-one onto the constructor parameters of
    ``OpenAIComputerUseAgent`` and onto the OpenAI Responses API
    "computer use" tool configuration.
    """

    # Display name for experiment bookkeeping; defaulted in __post_init__.
    agent_name: str | None = None
    model: str = "computer-use-preview"
    tool_type: str = "computer_use_preview"
    display_width: int = 1024
    display_height: int = 768
    environment: str = "browser"
    reasoning_summary: str = "concise"
    truncation: str = "auto"  # Always set to "auto" for OpenAI API
    action_set: HighLevelActionSetArgs | None = None
    enable_safety_checks: bool = False  # Optional, default to False, only use in demo mode
    implicit_agreement: bool = True  # Whether to require explicit agreement for actions or not

    def __post_init__(self):
        # Give the agent a stable default name when none was provided.
        if self.agent_name is None:
            self.agent_name = "OpenAIComputerUseAgent"

    def set_benchmark(self, benchmark, demo_mode):
        """No benchmark-specific setup is needed for this agent."""
        pass

    def set_reproducibility_mode(self):
        """This agent exposes no reproducibility knobs to toggle."""
        pass

    def make_agent(self):
        """Instantiate the agent configured by this dataclass."""
        return OpenAIComputerUseAgent(
            model=self.model,
            tool_type=self.tool_type,
            display_width=self.display_width,
            display_height=self.display_height,
            environment=self.environment,
            reasoning_summary=self.reasoning_summary,
            truncation=self.truncation,
            action_set=self.action_set,
            enable_safety_checks=self.enable_safety_checks,
            implicit_agreement=self.implicit_agreement,
        )
52+
53+
54+ class OpenAIComputerUseAgent (Agent ):
55+ def __init__ (self ,
56+ model : str ,
57+ tool_type : str ,
58+ display_width : int ,
59+ display_height : int ,
60+ environment : str ,
61+ reasoning_summary : str ,
62+ truncation : str ,
63+ action_set : HighLevelActionSetArgs ,
64+ enable_safety_checks : bool = False ,
65+ implicit_agreement : bool = True
66+ ):
67+ self .model = model
68+ self .reasoning_summary = reasoning_summary
69+ self .truncation = truncation
70+ self .enable_safety_checks = enable_safety_checks
71+ self .implicit_agreement = implicit_agreement
72+
73+ self .action_set = action_set .make_action_set ()
74+
75+ assert not self .enable_safety_checks and \
76+ (self .action_set .demo_mode is not None or self .action_set .demo_mode != "off" ), \
77+ "Safety checks are enabled but no demo mode is set. Please set demo_mode to 'all_blue' or 'off'."
78+
79+ self .computer_calls = []
80+ self .pending_checks = []
81+ self .previous_response_id = None
82+ self .last_call_id = None
83+ self .initialized = False # Set to True to call the API on the first get_action
84+ self .answer_assistant = None # Store the user answer to send to the assistant
85+ self .agent_info = AgentInfo ()
86+
87+ self .tools = [
88+ {
89+ "type" : tool_type ,
90+ "display_width" : display_width ,
91+ "display_height" : display_height ,
92+ "environment" : environment
93+ }
94+ ]
95+ self .inputs = []
96+
97+ def parse_action_to_bgym (self , action ) -> str :
98+ """
99+ Parse the action string returned by the OpenAI API into bgym format.
100+ """
101+ action_type = action .type
102+
103+ match (action_type ):
104+ case "click" :
105+ x , y = action .x , action .y
106+ button = action .button
107+ if button != "left" and button != "right" :
108+ button = "left"
109+ return f"mouse_click({ x } , { y } , button='{ button } ')"
110+
111+ case "scroll" :
112+ x , y = action .x , action .y
113+ dx , dy = action .scroll_x , action .scroll_y
114+ return f"scroll_at({ x } , { y } , { dx } , { dy } )"
115+
116+ case "keypress" :
117+ keys = action .keys
118+ for k in keys :
119+ if k .lower () == "enter" :
120+ return "keyboard_press('Enter')"
121+ elif k .lower () == "space" :
122+ return "keyboard_press(' ')"
123+ elif k .lower () == "ctrl" :
124+ return "keyboard_press('Ctrl')"
125+ else :
126+ return f"keyboard_press('{ k } ')"
127+
128+ case "type" :
129+ text = action .text
130+ return f"keyboard_insert_text('{ text } ')"
131+
132+ case "drag" :
133+ from_x , from_y = action .path [0 ].x , action .path [0 ].y
134+ to_x , to_y = action .path [- 1 ].x , action .path [- 1 ].y
135+ return f"mouse_drag_and_drop({ from_x } , { from_y } , { to_x } , { to_y } )"
136+
137+ case "move" :
138+ x , y = action .x , action .y
139+ return f"mouse_move({ x } , { y } )"
140+
141+ case "wait" :
142+ return "noop(2000)" # wait for 2 seconds
143+
144+ # The screenshot is already given in the observation, so we don't need to handle it here.
145+ case "screenshot" :
146+ # Nothing to do as screenshot is taken at each turn
147+ return "noop()"
148+
149+ case _:
150+ logging .error (f"No action found for { action_type } . Please check the action type." )
151+ return None
152+
153+ return action
154+
155+ def start_session (self , goal : str , screenshot_base64 : str ):
156+ """
157+ Call the OpenAI API with the provided goal and screenshot to initiate a session.
158+
159+ Args:
160+ goal (str): The goal or task description for the agent.
161+ screenshot_base64 (str): Base64 encoded screenshot of the current state.
162+ Returns:
163+ response: A list of Response objects. The response from the OpenAI API.
164+ """
165+ instruction = goal
166+ if self .implicit_agreement :
167+ instruction = f"""
168+ Please perform the following task, I can confirm it is safe and you can proceed with all actions without asking for confirmation.
169+
170+ Task:
171+ { goal }
172+ """
173+
174+ response = self .call_api (
175+ input = [
176+ {
177+ "role" : "user" ,
178+ "content" : [
179+ {
180+ "type" : "input_text" ,
181+ "text" : instruction
182+ },
183+ {
184+ "type" : "input_image" ,
185+ "image_url" : f"{ screenshot_base64 } "
186+ }
187+ ]
188+ }
189+ ],
190+ reasoning = {
191+ "summary" : self .reasoning_summary ,
192+ },
193+ )
194+ return response
195+
196+ def call_api (self , input : list , previous_response_id = None , ** kwargs ):
197+ response = client .responses .create (
198+ model = self .model ,
199+ previous_response_id = previous_response_id ,
200+ tools = self .tools ,
201+ input = input ,
202+ truncation = self .truncation , # Always set to "auto"
203+ ** kwargs
204+ )
205+ return response
206+
207+ def get_action (self , obs ):
208+ goal = obs ["goal" ]
209+ screenshot_base64 = image_to_jpg_base64_url (obs ["screenshot" ])
210+
211+ if not self .initialized :
212+ print ("Initializing OpenAI Computer Use Agent with goal:" , goal )
213+ response = self .start_session (goal , screenshot_base64 )
214+ for item in response .output :
215+ if item .type == "reasoning" :
216+ self .agent_info .think = item .summary [0 ].text if item .summary else None
217+ if item .type == "computer_call" :
218+ self .computer_calls .append (item )
219+ self .previous_response_id = response .id
220+ self .initialized = True
221+
222+ if len (self .computer_calls ) > 0 :
223+ logging .debug ("Found multiple computer calls in previous call. Processing them..." )
224+ computer_call = self .computer_calls .pop (0 )
225+ if not self .enable_safety_checks :
226+ # Bypass safety checks
227+ self .pending_checks = computer_call .pending_safety_checks
228+ print (f"Pending safety checks: { self .pending_checks } " )
229+ action = self .parse_action_to_bgym (computer_call .action )
230+ self .last_call_id = computer_call .call_id
231+ return action , self .agent_info
232+ else :
233+ logging .debug ("Last call ID:" , self .last_call_id )
234+ logging .debug ("Previous response ID:" , self .previous_response_id )
235+ self .inputs .append (
236+ {
237+ "call_id" : self .last_call_id ,
238+ "type" : "computer_call_output" ,
239+ "acknowledged_safety_checks" : self .pending_checks ,
240+ "output" :
241+ {
242+ "type" : "input_image" ,
243+ "image_url" : f"{ screenshot_base64 } " # current screenshot
244+ },
245+ }
246+ )
247+
248+ if self .answer_assistant :
249+ self .inputs .append (self .answer_assistant )
250+ self .answer_assistant = None
251+
252+ response = self .call_api (self .inputs , self .previous_response_id )
253+ self .previous_response_id = response .id
254+
255+ self .computer_calls = [item for item in response .output if item .type == "computer_call" ]
256+ if not self .computer_calls :
257+ logging .debug (f"No computer call found. Output from model: { response .output } " )
258+ for item in response .output :
259+ if item .type == "reasoning" :
260+ self .agent_info .think = item .summary [0 ].text if item .summary else None
261+ if hasattr (item , "role" ) and item .role == "assistant" :
262+ # Assume assitant asked for user confirmation
263+ # Always answer with: Yes, continue.
264+ self .answer_assistant = {
265+ "role" : "user" ,
266+ "content" : [
267+ {
268+ "type" : "input_text" ,
269+ "text" : "Yes, continue."
270+ }
271+ ]
272+ }
273+ return f"send_msg_to_user(\' { item .content [0 ].text } \' )" , self .agent_info
274+ logging .debug ("No action found in the response. Returning None." )
275+ return None , self .agent_info
276+
277+ computer_call = self .computer_calls .pop (0 )
278+ self .last_call_id = computer_call .call_id
279+ action = self .parse_action_to_bgym (computer_call .action )
280+ logging .debug ("Action:" , action )
281+ if not self .enable_safety_checks :
282+ # Bypass safety checks
283+ self .pending_checks = computer_call .pending_safety_checks
284+ else :
285+ pass
286+ # TODO: Handle safety checks if enabled in demo mode
287+ # self.pending_checks = computer_call.pending_safety_checks
288+ # for check in self.pending_checks:
289+ # do_something_to_acknowledge_check(check)
290+
291+ for item in response .output :
292+ if item .type == "reasoning" :
293+ self .agent_info .think = item .summary [0 ].text if item .summary else None
294+ break
295+
296+ return action , self .agent_info