
import bgym
import numpy as np
-from browsergym.core.observation import extract_screenshot
from PIL import Image, ImageDraw

+from agentlab.agents import agent_utils
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import (
+    BaseModelArgs,
    ClaudeResponseModelArgs,
    MessageBuilder,
+    OpenAIChatModelArgs,
    OpenAIResponseModelArgs,
+    OpenRouterModelArgs,
    ResponseLLMOutput,
+    VLLMModelArgs,
)
from agentlab.llm.tracking import cost_tracker_decorator
+from browsergym.core.observation import extract_screenshot

if TYPE_CHECKING:
    from openai.types.responses import Response


-def tag_screenshot_with_action(screenshot: Image.Image, action: str) -> Image.Image:
-    """
-    If the action is a coordinate action, try to render it on the screenshot.
-
-    e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
-
-    Args:
-        screenshot: The screenshot to tag.
-        action: The action to tag the screenshot with.
-
-    Returns:
-        The tagged screenshot. If the action cannot be parsed, a warning is
-        logged and the screenshot is returned unchanged.
-    """
-    if action.startswith("mouse_click"):
-        try:
-            coords = action[action.index("(") + 1 : action.index(")")].split(",")
-            coords = [c.strip() for c in coords]
-            if len(coords) != 2:
-                raise ValueError(f"Invalid coordinate format: {coords}")
-            if coords[0].startswith("x="):
-                coords[0] = coords[0][2:]
-            if coords[1].startswith("y="):
-                coords[1] = coords[1][2:]
-            x, y = float(coords[0].strip()), float(coords[1].strip())
-            draw = ImageDraw.Draw(screenshot)
-            radius = 5
-            draw.ellipse(
-                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
-            )
-        except (ValueError, IndexError) as e:
-            logging.warning(f"Failed to parse action '{action}': {e}")
-    return screenshot
-
-
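+# A minimal behavior sketch for the relocated helper (assuming
+# agentlab.agents.agent_utils.tag_screenshot_with_action keeps the signature
+# and behavior of the function removed above):
+#
+#   img = Image.fromarray(obs["screenshot"])
+#   img = agent_utils.tag_screenshot_with_action(img, "mouse_click(x=120, y=130)")
+#   # -> red dot drawn at (120, 130); unparseable actions are logged and skipped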
@dataclass
class ToolUseAgentArgs(AgentArgs):
    model_args: OpenAIResponseModelArgs = None
@@ -97,19 +65,9 @@ def __init__(
        self.model_args = model_args
        self.use_first_obs = use_first_obs
        self.tag_screenshot = tag_screenshot
-
        self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
-
        self.tools = self.action_set.to_tool_description(api=model_args.api)

-        # count tools tokens
-        from agentlab.llm.llm_utils import count_tokens
-
-        tool_str = json.dumps(self.tools, indent=2)
-        print(f"Tool description: {tool_str}")
-        tool_tokens = count_tokens(tool_str, model_args.model_name)
-        print(f"Tool tokens: {tool_tokens}")
-
        self.call_ids = []

        # self.tools.append(
@@ -131,7 +89,7 @@ def __init__(
        # )

        self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
-
+        self.msg_builder = model_args.get_message_builder()
        self.messages: list[MessageBuilder] = []

    def obs_preprocessor(self, obs):
@@ -140,7 +98,7 @@ def obs_preprocessor(self, obs):
            obs["screenshot"] = extract_screenshot(page)
            if self.tag_screenshot:
                screenshot = Image.fromarray(obs["screenshot"])
-                screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
+                screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
                obs["screenshot_tag"] = np.array(screenshot)
        else:
            raise ValueError("No page found in the observation.")
@@ -150,56 +108,31 @@ def obs_preprocessor(self, obs):
    @cost_tracker_decorator
    def get_action(self, obs: Any) -> float:
        if len(self.messages) == 0:
-            system_message = MessageBuilder.system().add_text(
-                "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
-            )
-            self.messages.append(system_message)
-
-            goal_message = MessageBuilder.user()
-            for content in obs["goal_object"]:
-                if content["type"] == "text":
-                    goal_message.add_text(content["text"])
-                elif content["type"] == "image_url":
-                    goal_message.add_image(content["image_url"])
-            self.messages.append(goal_message)
-
-            extra_info = []
-
-            extra_info.append(
-                """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for multiple selection in lists.\n"""
-            )
-
-            self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))
-
-            if self.use_first_obs:
-                msg = "Here is the first observation."
-                screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                if self.tag_screenshot:
-                    msg += " A red dot on screenshots indicates the previous click action."
-                message = MessageBuilder.user().add_text(msg)
-                message.add_image(image_to_png_base64_url(obs[screenshot_key]))
-                self.messages.append(message)
+            self.initialize_messages(obs)
        else:
-            if obs["last_action_error"] == "":
+            if obs["last_action_error"] == "":  # no error in the last action
                screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                tool_message = MessageBuilder.tool().add_image(
+                tool_message = self.msg_builder.tool().add_image(
                    image_to_png_base64_url(obs[screenshot_key])
                )
+                tool_message.update_last_raw_response(self.last_response)
                tool_message.add_tool_id(self.previous_call_id)
                self.messages.append(tool_message)
            else:
-                tool_message = MessageBuilder.tool().add_text(
+                tool_message = self.msg_builder.tool().add_text(
                    f"Function call failed: {obs['last_action_error']}"
                )
                tool_message.add_tool_id(self.previous_call_id)
+                tool_message.update_last_raw_response(self.last_response)
                self.messages.append(tool_message)

        response: ResponseLLMOutput = self.llm(messages=self.messages)

        action = response.action
        think = response.think
+        self.last_response = response
        self.previous_call_id = response.last_computer_call_id
-        self.messages.append(response.assistant_message)
+        self.messages.append(response.assistant_message)  # the assistant message carrying the tool call

        return (
            action,
@@ -210,6 +143,37 @@ def get_action(self, obs: Any) -> float:
            ),
        )

+    def initialize_messages(self, obs: Any) -> None:
+        system_message = self.msg_builder.system().add_text(
+            "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
+        )
+        self.messages.append(system_message)
+
+        goal_message = self.msg_builder.user()
+        for content in obs["goal_object"]:
+            if content["type"] == "text":
+                goal_message.add_text(content["text"])
+            elif content["type"] == "image_url":
+                goal_message.add_image(content["image_url"])
+        self.messages.append(goal_message)
+
+        extra_info = []
+
+        extra_info.append(
+            """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for multiple selection in lists.\n"""
+        )
+
+        self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info)))
+
+        if self.use_first_obs:
+            msg = "Here is the first observation."
+            screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
+            if self.tag_screenshot:
+                msg += " A red dot on screenshots indicates the previous click action."
+            message = self.msg_builder.user().add_text(msg)
+            message.add_image(image_to_png_base64_url(obs[screenshot_key]))
+            self.messages.append(message)
+

OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
    model_name="gpt-4.1",
@@ -220,6 +184,14 @@ def get_action(self, obs: Any) -> float:
    vision_support=True,
)

+OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
+    model_name="gpt-4o-2024-08-06",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)

CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
    model_name="claude-3-7-sonnet-20250219",
@@ -231,6 +203,103 @@ def get_action(self, obs: Any) -> float:
231203)
232204
233205
206+
207+
+# def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
+#     default_model_args = {
+#         "max_total_tokens": 200_000,
+#         "max_input_tokens": 180_000,
+#         "max_new_tokens": 2_000,
+#         "temperature": 0.1,
+#         "vision_support": True,
+#     }
+#     merged_args = {**default_model_args, **open_router_args}
+
+#     return OpenRouterModelArgs(model_name=model_name, **merged_args)
+
+
+# def get_openrouter_tool_use_agent(
+#     model_name: str,
+#     model_args: dict = {},
+#     use_first_obs=True,
+#     tag_screenshot=True,
+#     use_raw_page_output=True,
+# ) -> ToolUseAgentArgs:
+#     # TODO: check whether OpenRouter endpoint-specific args are working
+#     if not supports_tool_calling(model_name):
+#         raise ValueError(f"Model {model_name} does not support tool calling.")
+
+#     model_args = get_openrouter_model(model_name, **model_args)
+
+#     return ToolUseAgentArgs(
+#         model_args=model_args,
+#         use_first_obs=use_first_obs,
+#         tag_screenshot=tag_screenshot,
+#         use_raw_page_output=use_raw_page_output,
+#     )
+
+
+# OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")
+
+
AGENT_CONFIG = ToolUseAgentArgs(
    model_args=CLAUDE_MODEL_CONFIG,
)
+
+# MT_TOOL_USE_AGENT = ToolUseAgentArgs(
+#     model_args=OPENROUTER_MODEL,
+# )
+CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(
+        model_name="gpt-4o-2024-11-20",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=2_000,
+        temperature=0.7,
+        vision_support=True,
+    ),
+)
+
+
+OAI_CHAT_TOOL_AGENT = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
+)
+
+
+PROVIDER_FACTORY_MAP = {
+    "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
+    "openrouter": OpenRouterModelArgs,
+    "vllm": VLLMModelArgs,
+    "anthropic": ClaudeResponseModelArgs,
+}
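+# Dispatch sketch (illustrative): "openai" needs an endpoint spec to pick
+# between its two args classes; the other providers map straight to one class:
+#   PROVIDER_FACTORY_MAP["openai"]["response"] -> OpenAIResponseModelArgs
+#   PROVIDER_FACTORY_MAP["vllm"] -> VLLMModelArgs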
+
+
+def get_tool_use_agent(
+    api_provider: str,
+    model_args: dict,
+    tool_use_agent_args: dict = None,
+    api_provider_spec=None,
+) -> ToolUseAgentArgs:
+
+    if api_provider == "openai":
+        assert (
+            api_provider_spec is not None
+        ), "Endpoint specification is required for the OpenAI provider. Choose between 'chatcompletion' and 'response'."
+
+    model_args_factory = (
+        PROVIDER_FACTORY_MAP[api_provider]
+        if api_provider_spec is None
+        else PROVIDER_FACTORY_MAP[api_provider][api_provider_spec]
+    )
+
+    # Create the agent with model arguments from the factory
+    agent = ToolUseAgentArgs(
+        model_args=model_args_factory(**model_args), **(tool_use_agent_args or {})
+    )
+    return agent
+
+
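+# Minimal usage sketch (hypothetical values; assumes the model_args dict keys
+# match fields of the selected BaseModelArgs subclass):
+#
+# agent_args = get_tool_use_agent(
+#     api_provider="openai",
+#     model_args={"model_name": "gpt-4.1", "max_new_tokens": 2_000},
+#     api_provider_spec="response",
+# )
+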
+## We want to support three providers:
+# Anthropic
+# OpenAI
+# vllm (uses the OpenAI API)