66from agentlab .llm .llm_utils import image_to_jpg_base64_url
77
88import openai
9+
910client = openai .OpenAI ()
1011
1112
@@ -14,15 +15,16 @@ class OpenAIComputerUseAgentArgs(AbstractAgentArgs):
1415 """
1516 Arguments for the OpenAI Computer Use Agent.
1617 """
17- agent_name : str = None
18- model : str = "computer-use-preview"
18+
19+ agent_name : str = None
20+ model : str = "computer-use-preview"
1921 tool_type : str = "computer_use_preview"
2022 display_width : int = 1024
2123 display_height : int = 768
2224 environment : str = "browser"
2325 reasoning_summary : str = "concise"
2426 truncation : str = "auto" # Always set to "auto" for OpenAI API
25- action_set : HighLevelActionSetArgs = None
27+ action_set : HighLevelActionSetArgs = None
2628 enable_safety_checks : bool = False # Optional, default to False, only use in demo mode
2729 implicit_agreement : bool = True # Whether to require explicit agreement for actions or not
2830
@@ -47,23 +49,24 @@ def make_agent(self):
4749 truncation = self .truncation ,
4850 action_set = self .action_set ,
4951 enable_safety_checks = self .enable_safety_checks ,
50- implicit_agreement = self .implicit_agreement
52+ implicit_agreement = self .implicit_agreement ,
5153 )
5254
5355
5456class OpenAIComputerUseAgent (Agent ):
55- def __init__ (self ,
56- model : str ,
57- tool_type : str ,
58- display_width : int ,
59- display_height : int ,
60- environment : str ,
61- reasoning_summary : str ,
62- truncation : str ,
63- action_set : HighLevelActionSetArgs ,
64- enable_safety_checks : bool = False ,
65- implicit_agreement : bool = True
66- ):
57+ def __init__ (
58+ self ,
59+ model : str ,
60+ tool_type : str ,
61+ display_width : int ,
62+ display_height : int ,
63+ environment : str ,
64+ reasoning_summary : str ,
65+ truncation : str ,
66+ action_set : HighLevelActionSetArgs ,
67+ enable_safety_checks : bool = False ,
68+ implicit_agreement : bool = True ,
69+ ):
6770 self .model = model
6871 self .reasoning_summary = reasoning_summary
6972 self .truncation = truncation
@@ -72,24 +75,24 @@ def __init__(self,
7275
7376 self .action_set = action_set .make_action_set ()
7477
75- assert not self .enable_safety_checks and \
76- ( self .action_set .demo_mode is not None or self .action_set .demo_mode != "off" ), \
77- "Safety checks are enabled but no demo mode is set. Please set demo_mode to 'all_blue' or 'off'."
# Safety checks are only meaningful when the action set runs in a demo mode,
# so enforce the implication "safety checks enabled => demo mode active".
# The previous condition (`not enabled and (mode is not None or mode != "off")`)
# had a tautological inner clause and failed whenever safety checks were
# enabled, making the option impossible to turn on.
assert not self.enable_safety_checks or (
    self.action_set.demo_mode is not None and self.action_set.demo_mode != "off"
), (
    "Safety checks are enabled but no demo mode is set. "
    "Please set demo_mode to a value other than 'off' (e.g. 'all_blue'), "
    "or disable enable_safety_checks."
)
7881
7982 self .computer_calls = []
8083 self .pending_checks = []
81- self .previous_response_id = None
82- self .last_call_id = None
84+ self .previous_response_id = None
85+ self .last_call_id = None
8386 self .initialized = False # Set to True to call the API on the first get_action
84- self .answer_assistant = None # Store the user answer to send to the assistant
87+ self .answer_assistant = None # Store the user answer to send to the assistant
8588 self .agent_info = AgentInfo ()
8689
8790 self .tools = [
8891 {
8992 "type" : tool_type ,
9093 "display_width" : display_width ,
9194 "display_height" : display_height ,
92- "environment" : environment
95+ "environment" : environment ,
9396 }
9497 ]
9598 self .inputs = []
@@ -100,7 +103,7 @@ def parse_action_to_bgym(self, action) -> str:
100103 """
101104 action_type = action .type
102105
103- match (action_type ):
106+ match (action_type ):
104107 case "click" :
105108 x , y = action .x , action .y
106109 button = action .button
@@ -124,11 +127,11 @@ def parse_action_to_bgym(self, action) -> str:
124127 return "keyboard_press('Ctrl')"
125128 else :
126129 return f"keyboard_press('{ k } ')"
127-
130+
128131 case "type" :
129132 text = action .text
130133 return f"keyboard_insert_text('{ text } ')"
131-
134+
132135 case "drag" :
133136 from_x , from_y = action .path [0 ].x , action .path [0 ].y
134137 to_x , to_y = action .path [- 1 ].x , action .path [- 1 ].y
@@ -139,7 +142,7 @@ def parse_action_to_bgym(self, action) -> str:
139142 return f"mouse_move({ x } , { y } )"
140143
141144 case "wait" :
142- return "noop(2000)" # wait for 2 seconds
145+ return "noop(2000)" # wait for 2 seconds
143146
144147 # The screenshot is already given in the observation, so we don't need to handle it here.
145148 case "screenshot" :
@@ -149,7 +152,7 @@ def parse_action_to_bgym(self, action) -> str:
149152 case _:
150153 logging .error (f"No action found for { action_type } . Please check the action type." )
151154 return None
152-
155+
153156 return action
154157
155158 def start_session (self , goal : str , screenshot_base64 : str ):
@@ -174,17 +177,11 @@ def start_session(self, goal: str, screenshot_base64: str):
174177 response = self .call_api (
175178 input = [
176179 {
177- "role" : "user" ,
178- "content" : [
179- {
180- "type" : "input_text" ,
181- "text" : instruction
182- },
183- {
184- "type" : "input_image" ,
185- "image_url" : f"{ screenshot_base64 } "
186- }
187- ]
180+ "role" : "user" ,
181+ "content" : [
182+ {"type" : "input_text" , "text" : instruction },
183+ {"type" : "input_image" , "image_url" : f"{ screenshot_base64 } " },
184+ ],
188185 }
189186 ],
190187 reasoning = {
@@ -199,8 +196,8 @@ def call_api(self, input: list, previous_response_id=None, **kwargs):
199196 previous_response_id = previous_response_id ,
200197 tools = self .tools ,
201198 input = input ,
202- truncation = self .truncation , # Always set to "auto"
203- ** kwargs
199+ truncation = self .truncation , # Always set to "auto"
200+ ** kwargs ,
204201 )
205202 return response
206203
@@ -218,7 +215,7 @@ def get_action(self, obs):
218215 self .computer_calls .append (item )
219216 self .previous_response_id = response .id
220217 self .initialized = True
221-
218+
222219 if len (self .computer_calls ) > 0 :
223220 logging .debug ("Found multiple computer calls in previous call. Processing them..." )
224221 computer_call = self .computer_calls .pop (0 )
@@ -237,18 +234,17 @@ def get_action(self, obs):
237234 "call_id" : self .last_call_id ,
238235 "type" : "computer_call_output" ,
239236 "acknowledged_safety_checks" : self .pending_checks ,
240- "output" :
241- {
242- "type" : "input_image" ,
243- "image_url" : f"{ screenshot_base64 } " # current screenshot
244- },
237+ "output" : {
238+ "type" : "input_image" ,
239+ "image_url" : f"{ screenshot_base64 } " , # current screenshot
240+ },
245241 }
246242 )
247243
248244 if self .answer_assistant :
249245 self .inputs .append (self .answer_assistant )
250246 self .answer_assistant = None
251-
247+
252248 response = self .call_api (self .inputs , self .previous_response_id )
253249 self .previous_response_id = response .id
254250
@@ -263,17 +259,12 @@ def get_action(self, obs):
263259 # Always answer with: Yes, continue.
264260 self .answer_assistant = {
265261 "role" : "user" ,
266- "content" : [
267- {
268- "type" : "input_text" ,
269- "text" : "Yes, continue."
270- }
271- ]
262+ "content" : [{"type" : "input_text" , "text" : "Yes, continue." }],
272263 }
273- return f"send_msg_to_user(\ '{ item .content [0 ].text } \ ' )" , self .agent_info
264+ return f"send_msg_to_user('{ item .content [0 ].text } ')" , self .agent_info
274265 logging .debug ("No action found in the response. Returning None." )
275266 return None , self .agent_info
276-
267+
277268 computer_call = self .computer_calls .pop (0 )
278269 self .last_call_id = computer_call .call_id
279270 action = self .parse_action_to_bgym (computer_call .action )
@@ -293,4 +284,4 @@ def get_action(self, obs):
293284 self .agent_info .think = item .summary [0 ].text if item .summary else None
294285 break
295286
296- return action , self .agent_info
287+ return action , self .agent_info
0 commit comments