import bgym
import numpy as np
- from browsergym.core.observation import extract_screenshot
from PIL import Image, ImageDraw

+ from agentlab.agents import agent_utils
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import (
    ClaudeResponseModelArgs,
    MessageBuilder,
+     OpenAIChatModelArgs,
    OpenAIResponseModelArgs,
+     OpenRouterModelArgs,
    ResponseLLMOutput,
)
from agentlab.llm.tracking import cost_tracker_decorator
+ from browsergym.core.observation import extract_screenshot

if TYPE_CHECKING:
    from openai.types.responses import Response


- def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
-     """
-     If action is a coordinate action, try to render it on the screenshot.
-
-     e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
-
-     Args:
-         screenshot: The screenshot to tag.
-         action: The action to tag the screenshot with.
-
-     Returns:
-         The tagged screenshot.
-
-     Raises:
-         ValueError: If the action parsing fails.
-     """
-     if action.startswith("mouse_click"):
-         try:
-             coords = action[action.index("(") + 1 : action.index(")")].split(",")
-             coords = [c.strip() for c in coords]
-             if len(coords) != 2:
-                 raise ValueError(f"Invalid coordinate format: {coords}")
-             if coords[0].startswith("x="):
-                 coords[0] = coords[0][2:]
-             if coords[1].startswith("y="):
-                 coords[1] = coords[1][2:]
-             x, y = float(coords[0].strip()), float(coords[1].strip())
-             draw = ImageDraw.Draw(screenshot)
-             radius = 5
-             draw.ellipse(
-                 (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
-             )
-         except (ValueError, IndexError) as e:
-             logging.warning(f"Failed to parse action '{action}': {e}")
-     return screenshot
-
-
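The helper above moves to agentlab.agents.agent_utils (imported near the top of this diff). A minimal usage sketch, assuming the relocated function keeps the signature shown above:

from PIL import Image

from agentlab.agents import agent_utils

# Draws a red dot at (120, 130); non-click or unparseable actions return the image unchanged.
img = Image.new("RGB", (640, 480), "white")
img = agent_utils.tag_screenshot_with_action(img, "mouse_click(120, 130)")
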
@dataclass
class ToolUseAgentArgs(AgentArgs):
    model_args: OpenAIResponseModelArgs = None
@@ -97,19 +63,9 @@ def __init__(
        self.model_args = model_args
        self.use_first_obs = use_first_obs
        self.tag_screenshot = tag_screenshot
-
        self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
-
        self.tools = self.action_set.to_tool_description(api=model_args.api)

-         # count tools tokens
-         from agentlab.llm.llm_utils import count_tokens
-
-         tool_str = json.dumps(self.tools, indent=2)
-         print(f"Tool description: {tool_str}")
-         tool_tokens = count_tokens(tool_str, model_args.model_name)
-         print(f"Tool tokens: {tool_tokens}")
-
        self.call_ids = []

        # self.tools.append(
@@ -131,7 +87,7 @@ def __init__(
        # )

        self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
-
+         self.msg_builder = model_args.get_message_builder()
        self.messages: list[MessageBuilder] = []

    def obs_preprocessor(self, obs):
@@ -140,7 +96,7 @@ def obs_preprocessor(self, obs):
            obs["screenshot"] = extract_screenshot(page)
            if self.tag_screenshot:
                screenshot = Image.fromarray(obs["screenshot"])
-                 screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
+                 screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
                obs["screenshot_tag"] = np.array(screenshot)
        else:
            raise ValueError("No page found in the observation.")
@@ -150,56 +106,31 @@ def obs_preprocessor(self, obs):
    @cost_tracker_decorator
    def get_action(self, obs: Any) -> float:
        if len(self.messages) == 0:
-             system_message = MessageBuilder.system().add_text(
-                 "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
-             )
-             self.messages.append(system_message)
-
-             goal_message = MessageBuilder.user()
-             for content in obs["goal_object"]:
-                 if content["type"] == "text":
-                     goal_message.add_text(content["text"])
-                 elif content["type"] == "image_url":
-                     goal_message.add_image(content["image_url"])
-             self.messages.append(goal_message)
-
-             extra_info = []
-
-             extra_info.append(
-                 """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
-             )
-
-             self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))
-
-             if self.use_first_obs:
-                 msg = "Here is the first observation."
-                 screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                 if self.tag_screenshot:
-                     msg += " A red dot on screenshots indicate the previous click action."
-                 message = MessageBuilder.user().add_text(msg)
-                 message.add_image(image_to_png_base64_url(obs[screenshot_key]))
-                 self.messages.append(message)
+             self.initialize_messages(obs)
        else:
-             if obs["last_action_error"] == "":
+             if obs["last_action_error"] == "":  # no error in the last action
                screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                 tool_message = MessageBuilder.tool().add_image(
+                 tool_message = self.msg_builder.tool().add_image(
                    image_to_png_base64_url(obs[screenshot_key])
                )
+                 tool_message.update_last_raw_response(self.last_response)
                tool_message.add_tool_id(self.previous_call_id)
                self.messages.append(tool_message)
            else:
-                 tool_message = MessageBuilder.tool().add_text(
+                 tool_message = self.msg_builder.tool().add_text(
                    f"Function call failed: {obs['last_action_error']}"
                )
                tool_message.add_tool_id(self.previous_call_id)
+                 tool_message.update_last_raw_response(self.last_response)
                self.messages.append(tool_message)

        response: ResponseLLMOutput = self.llm(messages=self.messages)

        action = response.action
        think = response.think
+         self.last_response = response
        self.previous_call_id = response.last_computer_call_id
-         self.messages.append(response.assistant_message)
+         self.messages.append(response.assistant_message)  # the assistant message carrying the tool call

        return (
            action,
@@ -210,6 +141,37 @@ def get_action(self, obs: Any) -> float:
            ),
        )

+     def initialize_messages(self, obs: Any) -> None:
+         system_message = self.msg_builder.system().add_text(
+             "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
+         )
+         self.messages.append(system_message)
+
+         goal_message = self.msg_builder.user()
+         for content in obs["goal_object"]:
+             if content["type"] == "text":
+                 goal_message.add_text(content["text"])
+             elif content["type"] == "image_url":
+                 goal_message.add_image(content["image_url"])
+         self.messages.append(goal_message)
+
+         extra_info = []
+
+         extra_info.append(
+             """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for multiple selection in lists.\n"""
+         )
+
+         self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info)))
+
+         if self.use_first_obs:
+             msg = "Here is the first observation."
+             screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
+             if self.tag_screenshot:
+                 msg += " A red dot on screenshots indicates the previous click action."
+             message = self.msg_builder.user().add_text(msg)
+             message.add_image(image_to_png_base64_url(obs[screenshot_key]))
+             self.messages.append(message)
+

OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
    model_name="gpt-4.1",
@@ -220,6 +182,14 @@ def get_action(self, obs: Any) -> float:
    vision_support=True,
)

+ OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
+     model_name="gpt-4o-2024-08-06",
+     max_total_tokens=200_000,
+     max_input_tokens=200_000,
+     max_new_tokens=2_000,
+     temperature=0.1,
+     vision_support=True,
+ )

CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
    model_name="claude-3-7-sonnet-20250219",
@@ -231,6 +201,114 @@ def get_action(self, obs: Any) -> float:
)


+ def supports_tool_calling(model_name: str) -> bool:
+     """
+     Check if the model supports tool calling.
+
+     Args:
+         model_name (str): The name of the model.
+
+     Returns:
+         bool: True if the model supports tool calling, False otherwise.
+     """
+     import os
+
+     import openai
+
+     client = openai.Client(
+         api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"
+     )
+     try:
+         response = client.chat.completions.create(
+             model=model_name,
+             messages=[{"role": "user", "content": "Call the test tool"}],
+             tools=[
+                 {
+                     "type": "function",
+                     "function": {
+                         "name": "dummy_tool",
+                         "description": "Just a test tool",
+                         "parameters": {
+                             "type": "object",
+                             "properties": {},
+                         },
+                     },
+                 }
+             ],
+             tool_choice="required",
+         )
+         response = response.to_dict()
+         return "tool_calls" in response["choices"][0]["message"]
+     except Exception as e:
+         print(f"Model '{model_name}' error: {e}")
+         return False
+
+
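A minimal usage sketch for the probe above (assumes OPENROUTER_API_KEY is set; note that the check sends one real chat-completions request, so it costs a few tokens):

# Hypothetical check before building an agent around an OpenRouter model.
if supports_tool_calling("google/gemini-2.5-pro-preview"):
    print("tool calling supported")
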
+ def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
+     default_model_args = {
+         "max_total_tokens": 200_000,
+         "max_input_tokens": 180_000,
+         "max_new_tokens": 2_000,
+         "temperature": 0.1,
+         "vision_support": True,
+     }
+     merged_args = {**default_model_args, **open_router_args}
+
+     return OpenRouterModelArgs(model_name=model_name, **merged_args)
+
+
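Keyword overrides take precedence over these defaults; a usage sketch (the model name and the override value are illustrative):

# Hypothetical: same defaults, but deterministic sampling.
claude_args = get_openrouter_model("anthropic/claude-3.7-sonnet", temperature=0.0)
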
+ def get_openrouter_tool_use_agent(
+     model_name: str,
+     model_args: dict = {},
+     use_first_obs=True,
+     tag_screenshot=True,
+     use_raw_page_output=True,
+ ) -> ToolUseAgentArgs:
+     # TODO: check that OpenRouter endpoint-specific args are honored
+     if not supports_tool_calling(model_name):
+         raise ValueError(f"Model {model_name} does not support tool calling.")
+
+     model_args = get_openrouter_model(model_name, **model_args)
+
+     return ToolUseAgentArgs(
+         model_args=model_args,
+         use_first_obs=use_first_obs,
+         tag_screenshot=tag_screenshot,
+         use_raw_page_output=use_raw_page_output,
+     )
+
+
+ OPENROUTER_MODEL = get_openrouter_model("google/gemini-2.5-pro-preview")
+
+
AGENT_CONFIG = ToolUseAgentArgs(
    model_args=CLAUDE_MODEL_CONFIG,
)
+
+ MT_TOOL_USE_AGENT = ToolUseAgentArgs(
+     model_args=OPENROUTER_MODEL,
+ )
+
+ CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
+     model_args=OpenAIChatModelArgs(
+         model_name="gpt-4o-2024-11-20",
+         max_total_tokens=200_000,
+         max_input_tokens=200_000,
+         max_new_tokens=2_000,
+         temperature=0.7,
+         vision_support=True,
+     ),
+ )
+
+
+ OAI_CHAT_TOOL_AGENT = ToolUseAgentArgs(
+     model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06"),
+     use_first_obs=False,
+     tag_screenshot=False,
+     use_raw_page_output=True,
+ )
+
+
+ ## We have three providers that we want to support:
+ # Anthropic
+ # OpenAI
+ # vllm (uses the OpenAI-compatible API)
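For the vllm path, a hedged sketch of what a config could look like through the OpenAI-compatible chat API (the model name and token limits are placeholders, and it assumes the underlying OpenAI client is pointed at the vllm server, e.g. via the OPENAI_BASE_URL environment variable):

# Hypothetical vllm-served model reusing the OpenAI chat path.
VLLM_AGENT_CONFIG = ToolUseAgentArgs(
    model_args=OpenAIChatModelArgs(
        model_name="meta-llama/Llama-3.1-70B-Instruct",  # placeholder for whatever vllm serves
        max_total_tokens=32_000,
        max_input_tokens=30_000,
        max_new_tokens=2_000,
        temperature=0.1,
        vision_support=False,
    ),
)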