 from typing import TYPE_CHECKING, Any
 
 import bgym
+import numpy as np
 from browsergym.core.observation import extract_screenshot
+from PIL import Image, ImageDraw
 
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
-from agentlab.llm.response_api import OpenAIResponseModelArgs
+from agentlab.llm.response_api import (
+    ClaudeResponseModelArgs,
+    MessageBuilder,
+    OpenAIResponseModelArgs,
+)
 from agentlab.llm.tracking import cost_tracker_decorator
 
 if TYPE_CHECKING:
     from openai.types.responses import Response
 
 
+def tag_screenshot_with_action(screenshot: Image.Image, action: str) -> Image.Image:
+    """
+    If the action is a coordinate action, try to render it on the screenshot.
+
+    e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
+
+    Args:
+        screenshot: The screenshot to tag.
+        action: The action to tag the screenshot with.
+
+    Returns:
+        The tagged screenshot. If the coordinates cannot be parsed, a warning
+        is logged and the screenshot is returned unmodified.
+    """
+    if action.startswith("mouse_click"):
+        try:
+            # Accept both "mouse_click(120, 130)" and "mouse_click(x=120, y=130)".
+            coords = action[action.index("(") + 1 : action.index(")")].split(",")
+            coords = [c.strip() for c in coords]
+            if len(coords) != 2:
+                raise ValueError(f"Invalid coordinate format: {coords}")
+            if coords[0].startswith("x="):
+                coords[0] = coords[0][2:]
+            if coords[1].startswith("y="):
+                coords[1] = coords[1][2:]
+            x, y = float(coords[0].strip()), float(coords[1].strip())
+            draw = ImageDraw.Draw(screenshot)
+            radius = 5
+            draw.ellipse(
+                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
+            )
+        except (ValueError, IndexError) as e:
+            logging.warning(f"Failed to parse action '{action}': {e}")
+    return screenshot
+
+
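+# A minimal usage sketch (illustrative only; `img` is a hypothetical PIL image):
+#   img = Image.open("screenshot.png")
+#   img = tag_screenshot_with_action(img, "mouse_click(x=120, y=130)")
+#   img.save("screenshot_tagged.png")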
 @dataclass
 class ToolUseAgentArgs(AgentArgs):
     temperature: float = 0.1
@@ -48,14 +91,18 @@ def __init__(
         self,
         temperature: float,
         model_args: OpenAIResponseModelArgs,
+        use_first_obs: bool = True,
+        tag_screenshot: bool = True,
     ):
         self.temperature = temperature
         self.chat = model_args.make_model()
         self.model_args = model_args
+        self.use_first_obs = use_first_obs
+        self.tag_screenshot = tag_screenshot
 
         self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
 
-        self.tools = self.action_set.to_tool_description()
+        self.tools = self.action_set.to_tool_description(api="anthropic")
 
         # self.tools.append(
         #     {
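+        # Rough shape of an Anthropic-format tool entry (assumed, for orientation):
+        #   {"name": "mouse_click", "description": "...",
+        #    "input_schema": {"type": "object", "properties": {"x": {...}, "y": {...}}}}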
@@ -77,87 +124,94 @@ def __init__(
 
         self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
 
-        self.messages = []
+        self.messages: list[MessageBuilder] = []
 
     def obs_preprocessor(self, obs):
         page = obs.pop("page", None)
         if page is not None:
             obs["screenshot"] = extract_screenshot(page)
+            if self.tag_screenshot:
+                # Round-trip through PIL to draw the previous click, then back to ndarray.
+                obs["screenshot"] = Image.fromarray(obs["screenshot"])
+                obs["screenshot"] = tag_screenshot_with_action(
+                    obs["screenshot"], obs["last_action"]
+                )
+                obs["screenshot"] = np.array(obs["screenshot"])
         else:
             raise ValueError("No page found in the observation.")
 
         return obs
 
     @cost_tracker_decorator
     def get_action(self, obs: Any) -> tuple[str, dict]:
-
         if len(self.messages) == 0:
-            system_message = {
-                "role": "system",
-                "content": "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal.",
-            }
-            goal_object = [el for el in obs["goal_object"]]
-            for content in goal_object:
-                if content["type"] == "text":
-                    content["type"] = "input_text"
-                elif content["type"] == "image_url":
-                    content["type"] = "input_image"
-            goal_message = {"role": "user", "content": goal_object}
-            goal_message["content"].append(
-                {
-                    "type": "input_image",
-                    "image_url": image_to_png_base64_url(obs["screenshot"]),
-                }
+            system_message = MessageBuilder.system().add_text(
+                "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
             )
             self.messages.append(system_message)
+
+            goal_message = MessageBuilder.user()
+            for content in obs["goal_object"]:
+                if content["type"] == "text":
+                    goal_message.add_text(content["text"])
+                elif content["type"] == "image_url":
+                    goal_message.add_image(content["image_url"])
             self.messages.append(goal_message)
+
+            if self.use_first_obs:
+                message = MessageBuilder.user().add_text(
+                    "Here is the first observation. A red dot on screenshots indicates the previous click action:"
+                )
+                message.add_image(image_to_png_base64_url(obs["screenshot"]))
+                self.messages.append(message)
         else:
             if obs["last_action_error"] == "":
-                self.messages.append(
-                    {
-                        "type": "function_call_output",
-                        "call_id": self.previous_call_id,
-                        "output": "Function call executed, see next observation.",
-                    }
-                )
-                self.messages.append(
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "input_image",
-                                "image_url": image_to_png_base64_url(obs["screenshot"]),
-                            }
-                        ],
-                    }
+                tool_message = MessageBuilder.tool().add_image(
+                    image_to_png_base64_url(obs["screenshot"])
                 )
+                tool_message.add_tool_id(self.previous_call_id)
+                self.messages.append(tool_message)
             else:
-                self.messages.append(
-                    {
-                        "type": "function_call_output",
-                        "call_id": self.previous_call_id,
-                        "output": f"Function call failed: {obs['last_action_error']}",
-                    }
+                tool_message = MessageBuilder.tool().add_text(
+                    f"Function call failed: {obs['last_action_error']}"
                 )
+                tool_message.add_tool_id(self.previous_call_id)
+                self.messages.append(tool_message)
 
+        # Render MessageBuilder objects to Anthropic-format dicts; raw assistant
+        # dicts (appended after each response) pass through unchanged.
+        messages = []
+        for msg in self.messages:
+            if isinstance(msg, MessageBuilder):
+                messages += msg.to_anthropic()
+            else:
+                messages.append(msg)
         response: "Response" = self.llm(
-            messages=self.messages,
+            messages=messages,
             temperature=self.temperature,
         )
 
         action = "noop()"
         think = ""
-        for output in response.output:
-            if output.type == "function_call":
-                arguments = json.loads(output.arguments)
-                action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
-                self.previous_call_id = output.call_id
-                self.messages.append(output)
-                break
-            elif output.type == "reasoning":
-                if len(output.summary) > 0:
-                    think += output.summary[0].text + "\n"
-                self.messages.append(output)
+        # openai
+        # for output in response.output:
+        #     if output.type == "function_call":
+        #         arguments = json.loads(output.arguments)
+        #         action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
+        #         self.previous_call_id = output.call_id
+        #         self.messages.append(output)
+        #         break
+        #     elif output.type == "reasoning":
+        #         if len(output.summary) > 0:
+        #             think += output.summary[0].text + "\n"
+        #         self.messages.append(output)
+
+        # anthropic
+        for output in response.content:
+            if output.type == "text":
+                think += output.text
+            elif output.type == "tool_use":
+                # e.g. tool_use(name="mouse_click", input={"x": 120, "y": 130})
+                #   -> 'mouse_click(x=120, y=130)'; string arguments are quoted.
+                action = f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
+                self.previous_call_id = output.id
+
+        self.messages.append({"role": "assistant", "content": response.content})
 
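+        # The running transcript now looks roughly like (sketch):
+        #   system -> user(goal) -> user(first screenshot)
+        #   -> assistant(tool_use) -> tool(screenshot) -> assistant(tool_use) -> ...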
         return (
             action,
@@ -170,15 +224,26 @@ def get_action(self, obs: Any) -> tuple[str, dict]:
 
 
 MODEL_CONFIG = OpenAIResponseModelArgs(
-    model_name="o4-mini-2025-04-16",
+    model_name="gpt-4o",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
+
+
+CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
+    model_name="claude-3-7-sonnet-20250219",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
-    max_new_tokens=100_000,
+    max_new_tokens=2_000,
     temperature=0.1,
     vision_support=True,
 )
 
+
 AGENT_CONFIG = ToolUseAgentArgs(
     temperature=0.1,
-    model_args=MODEL_CONFIG,
+    model_args=CLAUDE_MODEL_CONFIG,
 )