
Commit fe05d75

Refactor: Simplify message builder methods and add support for chat completion API
1 parent: ce72b41

File tree

2 files changed (+511 −125 lines)


src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 161 additions & 83 deletions
@@ -5,60 +5,26 @@

 import bgym
 import numpy as np
-from browsergym.core.observation import extract_screenshot
 from PIL import Image, ImageDraw

+from agentlab.agents import agent_utils
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
 from agentlab.llm.response_api import (
     ClaudeResponseModelArgs,
     MessageBuilder,
+    OpenAIChatModelArgs,
     OpenAIResponseModelArgs,
+    OpenRouterModelArgs,
     ResponseLLMOutput,
 )
 from agentlab.llm.tracking import cost_tracker_decorator
+from browsergym.core.observation import extract_screenshot

 if TYPE_CHECKING:
     from openai.types.responses import Response


-def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
-    """
-    If action is a coordinate action, try to render it on the screenshot.
-
-    e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
-
-    Args:
-        screenshot: The screenshot to tag.
-        action: The action to tag the screenshot with.
-
-    Returns:
-        The tagged screenshot.
-
-    Raises:
-        ValueError: If the action parsing fails.
-    """
-    if action.startswith("mouse_click"):
-        try:
-            coords = action[action.index("(") + 1 : action.index(")")].split(",")
-            coords = [c.strip() for c in coords]
-            if len(coords) != 2:
-                raise ValueError(f"Invalid coordinate format: {coords}")
-            if coords[0].startswith("x="):
-                coords[0] = coords[0][2:]
-            if coords[1].startswith("y="):
-                coords[1] = coords[1][2:]
-            x, y = float(coords[0].strip()), float(coords[1].strip())
-            draw = ImageDraw.Draw(screenshot)
-            radius = 5
-            draw.ellipse(
-                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
-            )
-        except (ValueError, IndexError) as e:
-            logging.warning(f"Failed to parse action '{action}': {e}")
-    return screenshot
-
-
 @dataclass
 class ToolUseAgentArgs(AgentArgs):
     model_args: OpenAIResponseModelArgs = None
@@ -97,19 +63,9 @@ def __init__(
         self.model_args = model_args
         self.use_first_obs = use_first_obs
         self.tag_screenshot = tag_screenshot
-
         self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
-
         self.tools = self.action_set.to_tool_description(api=model_args.api)

-        # count tools tokens
-        from agentlab.llm.llm_utils import count_tokens
-
-        tool_str = json.dumps(self.tools, indent=2)
-        print(f"Tool description: {tool_str}")
-        tool_tokens = count_tokens(tool_str, model_args.model_name)
-        print(f"Tool tokens: {tool_tokens}")
-
         self.call_ids = []

         # self.tools.append(
@@ -131,7 +87,7 @@ def __init__(
         # )

         self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
-
+        self.msg_builder = model_args.get_message_builder()
         self.messages: list[MessageBuilder] = []

     def obs_preprocessor(self, obs):
@@ -140,7 +96,7 @@ def obs_preprocessor(self, obs):
             obs["screenshot"] = extract_screenshot(page)
             if self.tag_screenshot:
                 screenshot = Image.fromarray(obs["screenshot"])
-                screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
+                screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
                 obs["screenshot_tag"] = np.array(screenshot)
         else:
             raise ValueError("No page found in the observation.")
@@ -150,56 +106,31 @@ def obs_preprocessor(self, obs):
     @cost_tracker_decorator
     def get_action(self, obs: Any) -> float:
         if len(self.messages) == 0:
-            system_message = MessageBuilder.system().add_text(
-                "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
-            )
-            self.messages.append(system_message)
-
-            goal_message = MessageBuilder.user()
-            for content in obs["goal_object"]:
-                if content["type"] == "text":
-                    goal_message.add_text(content["text"])
-                elif content["type"] == "image_url":
-                    goal_message.add_image(content["image_url"])
-            self.messages.append(goal_message)
-
-            extra_info = []
-
-            extra_info.append(
-                """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
-            )
-
-            self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))
-
-            if self.use_first_obs:
-                msg = "Here is the first observation."
-                screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                if self.tag_screenshot:
-                    msg += " A red dot on screenshots indicate the previous click action."
-                message = MessageBuilder.user().add_text(msg)
-                message.add_image(image_to_png_base64_url(obs[screenshot_key]))
-                self.messages.append(message)
+            self.initalize_messages(obs)
         else:
-            if obs["last_action_error"] == "":
+            if obs["last_action_error"] == "":  # Check No error in the last action
                 screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                tool_message = MessageBuilder.tool().add_image(
+                tool_message = self.msg_builder.tool().add_image(
                     image_to_png_base64_url(obs[screenshot_key])
                 )
+                tool_message.update_last_raw_response(self.last_response)
                 tool_message.add_tool_id(self.previous_call_id)
                 self.messages.append(tool_message)
             else:
-                tool_message = MessageBuilder.tool().add_text(
+                tool_message = self.msg_builder.tool().add_text(
                     f"Function call failed: {obs['last_action_error']}"
                 )
                 tool_message.add_tool_id(self.previous_call_id)
+                tool_message.update_last_raw_response(self.last_response)
                 self.messages.append(tool_message)

         response: ResponseLLMOutput = self.llm(messages=self.messages)

         action = response.action
         think = response.think
+        self.last_response = response
         self.previous_call_id = response.last_computer_call_id
-        self.messages.append(response.assistant_message)
+        self.messages.append(response.assistant_message)  # this is tool call

         return (
             action,
@@ -210,6 +141,37 @@ def get_action(self, obs: Any) -> float:
             ),
         )

+    def initalize_messages(self, obs: Any) -> None:
+        system_message = self.msg_builder.system().add_text(
+            "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
+        )
+        self.messages.append(system_message)
+
+        goal_message = self.msg_builder.user()
+        for content in obs["goal_object"]:
+            if content["type"] == "text":
+                goal_message.add_text(content["text"])
+            elif content["type"] == "image_url":
+                goal_message.add_image(content["image_url"])
+        self.messages.append(goal_message)
+
+        extra_info = []
+
+        extra_info.append(
+            """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
+        )
+
+        self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info)))
+
+        if self.use_first_obs:
+            msg = "Here is the first observation."
+            screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
+            if self.tag_screenshot:
+                msg += " A red dot on screenshots indicate the previous click action."
+            message = self.msg_builder.user().add_text(msg)
+            message.add_image(image_to_png_base64_url(obs[screenshot_key]))
+            self.messages.append(message)
+

 OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
     model_name="gpt-4.1",
@@ -220,6 +182,14 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )

+OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
+    model_name="gpt-4o-2024-08-06",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)

 CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
     model_name="claude-3-7-sonnet-20250219",
@@ -231,6 +201,114 @@ def get_action(self, obs: Any) -> float:
 )


+def supports_tool_calling(model_name: str) -> bool:
+    """
+    Check if the model supports tool calling.
+
+    Args:
+        model_name (str): The name of the model.
+
+    Returns:
+        bool: True if the model supports tool calling, False otherwise.
+    """
+    import os
+
+    import openai
+
+    client = openai.Client(
+        api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"
+    )
+    try:
+        response = client.chat.completions.create(
+            model=model_name,
+            messages=[{"role": "user", "content": "Call the test tool"}],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "dummy_tool",
+                        "description": "Just a test tool",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {},
+                        },
+                    },
+                }
+            ],
+            tool_choice="required",
+        )
+        response = response.to_dict()
+        return "tool_calls" in response["choices"][0]["message"]
+    except Exception as e:
+        print(f"Model '{model_name}' error: {e}")
+        return False
+
+
+def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
+    default_model_args = {
+        "max_total_tokens": 200_000,
+        "max_input_tokens": 180_000,
+        "max_new_tokens": 2_000,
+        "temperature": 0.1,
+        "vision_support": True,
+    }
+    merged_args = {**default_model_args, **open_router_args}
+
+    return OpenRouterModelArgs(model_name=model_name, **merged_args)
+
+
+def get_openrouter_tool_use_agent(
+    model_name: str,
+    model_args: dict = {},
+    use_first_obs=True,
+    tag_screenshot=True,
+    use_raw_page_output=True,
+) -> ToolUseAgentArgs:
+    # To Do: Check if OpenRouter endpoint specific args are working
+    if not supports_tool_calling(model_name):
+        raise ValueError(f"Model {model_name} does not support tool calling.")
+
+    model_args = get_openrouter_model(model_name, **model_args)
+
+    return ToolUseAgentArgs(
+        model_args=model_args,
+        use_first_obs=use_first_obs,
+        tag_screenshot=tag_screenshot,
+        use_raw_page_output=use_raw_page_output,
+    )
+
+
+OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")
+
+
 AGENT_CONFIG = ToolUseAgentArgs(
     model_args=CLAUDE_MODEL_CONFIG,
 )
+
+MT_TOOL_USE_AGENT = ToolUseAgentArgs(
+    model_args=OPENROUTER_MODEL,
+)
+CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(
+        model_name="gpt-4o-2024-11-20",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=2_000,
+        temperature=0.7,
+        vision_support=True,
+    ),
+)
+
+
+OAI_CHAT_TOOl_AGENT = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06"),
+    use_first_obs=False,
+    tag_screenshot=False,
+    use_raw_page_output=True,
+)
+
+
+## We have three providers that we want to support.
+# Anthropic
+# OpenAI
+# vllm (uses OpenAI API)
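
Taken together, these changes route all message construction through a provider-specific builder returned by model_args.get_message_builder(), instead of calling MessageBuilder directly. A condensed sketch of one step of the resulting loop (illustrative only; obs, last_response, and previous_call_id stand in for the agent's state, and the builder methods are the ones used in the diff above):

    # Sketch of the refactored message flow (names as used in this commit).
    msg_builder = model_args.get_message_builder()  # provider-specific builder
    messages = []

    # First step: system prompt and goal (see initalize_messages above).
    messages.append(msg_builder.system().add_text("You are an agent. ..."))
    messages.append(msg_builder.user().add_text("Your goal: ..."))

    # Every later step: feed the new screenshot back as a tool result, tagged
    # with the previous call id and raw response so each backend (Responses
    # API, chat completions, Anthropic) can serialize it correctly.
    tool_message = msg_builder.tool().add_image(image_to_png_base64_url(obs["screenshot"]))
    tool_message.update_last_raw_response(last_response)
    tool_message.add_tool_id(previous_call_id)
    messages.append(tool_message)

    response = llm(messages=messages)            # a ResponseLLMOutput
    messages.append(response.assistant_message)  # keep the tool call in history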

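For reference, a hypothetical use of the new OpenRouter helpers. supports_tool_calling issues one real probe request with a dummy tool and tool_choice="required", so each check costs a (small) completion, and OPENROUTER_API_KEY must be set in the environment:

    # Hypothetical usage; the model slug is the one hard-coded in the diff above.
    slug = "google/gemini-2.5-pro-preview"
    if supports_tool_calling(slug):
        agent_args = get_openrouter_tool_use_agent(
            slug,
            model_args={"temperature": 0.0},  # overrides the 0.1 default
        )
    else:
        raise ValueError(f"Model {slug} does not support tool calling.")

Note that get_openrouter_tool_use_agent returns a ToolUseAgentArgs, so OPENROUTER_MODEL is agent args rather than model args; MT_TOOL_USE_AGENT then passes it as model_args to a second ToolUseAgentArgs, which looks like a naming slip worth revisiting.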