
Commit ffd5c5e

Merge pull request #248 from ServiceNow/aj/tool_use_agent_chat_completion_support
Aj/tool use agent chat completion support
2 parents ce72b41 + 7d8a08c commit ffd5c5e

File tree

3 files changed: +630 −119 lines changed

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 152 additions & 83 deletions
@@ -5,60 +5,28 @@
 
 import bgym
 import numpy as np
-from browsergym.core.observation import extract_screenshot
 from PIL import Image, ImageDraw
 
+from agentlab.agents import agent_utils
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
 from agentlab.llm.response_api import (
+    BaseModelArgs,
     ClaudeResponseModelArgs,
     MessageBuilder,
+    OpenAIChatModelArgs,
     OpenAIResponseModelArgs,
+    OpenRouterModelArgs,
     ResponseLLMOutput,
+    VLLMModelArgs,
 )
 from agentlab.llm.tracking import cost_tracker_decorator
+from browsergym.core.observation import extract_screenshot
 
 if TYPE_CHECKING:
     from openai.types.responses import Response
 
 
-def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
-    """
-    If action is a coordinate action, try to render it on the screenshot.
-
-    e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
-
-    Args:
-        screenshot: The screenshot to tag.
-        action: The action to tag the screenshot with.
-
-    Returns:
-        The tagged screenshot.
-
-    Raises:
-        ValueError: If the action parsing fails.
-    """
-    if action.startswith("mouse_click"):
-        try:
-            coords = action[action.index("(") + 1 : action.index(")")].split(",")
-            coords = [c.strip() for c in coords]
-            if len(coords) != 2:
-                raise ValueError(f"Invalid coordinate format: {coords}")
-            if coords[0].startswith("x="):
-                coords[0] = coords[0][2:]
-            if coords[1].startswith("y="):
-                coords[1] = coords[1][2:]
-            x, y = float(coords[0].strip()), float(coords[1].strip())
-            draw = ImageDraw.Draw(screenshot)
-            radius = 5
-            draw.ellipse(
-                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
-            )
-        except (ValueError, IndexError) as e:
-            logging.warning(f"Failed to parse action '{action}': {e}")
-    return screenshot
-
-
 @dataclass
 class ToolUseAgentArgs(AgentArgs):
     model_args: OpenAIResponseModelArgs = None
@@ -97,19 +65,9 @@ def __init__(
         self.model_args = model_args
         self.use_first_obs = use_first_obs
         self.tag_screenshot = tag_screenshot
-
         self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
-
         self.tools = self.action_set.to_tool_description(api=model_args.api)
 
-        # count tools tokens
-        from agentlab.llm.llm_utils import count_tokens
-
-        tool_str = json.dumps(self.tools, indent=2)
-        print(f"Tool description: {tool_str}")
-        tool_tokens = count_tokens(tool_str, model_args.model_name)
-        print(f"Tool tokens: {tool_tokens}")
-
         self.call_ids = []
 
         # self.tools.append(
@@ -131,7 +89,7 @@ def __init__(
         # )
 
         self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
-
+        self.msg_builder = model_args.get_message_builder()
         self.messages: list[MessageBuilder] = []
 
     def obs_preprocessor(self, obs):
@@ -140,7 +98,7 @@ def obs_preprocessor(self, obs):
             obs["screenshot"] = extract_screenshot(page)
             if self.tag_screenshot:
                 screenshot = Image.fromarray(obs["screenshot"])
-                screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
+                screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
                 obs["screenshot_tag"] = np.array(screenshot)
         else:
             raise ValueError("No page found in the observation.")
@@ -150,56 +108,31 @@ def obs_preprocessor(self, obs):
     @cost_tracker_decorator
     def get_action(self, obs: Any) -> float:
         if len(self.messages) == 0:
-            system_message = MessageBuilder.system().add_text(
-                "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
-            )
-            self.messages.append(system_message)
-
-            goal_message = MessageBuilder.user()
-            for content in obs["goal_object"]:
-                if content["type"] == "text":
-                    goal_message.add_text(content["text"])
-                elif content["type"] == "image_url":
-                    goal_message.add_image(content["image_url"])
-            self.messages.append(goal_message)
-
-            extra_info = []
-
-            extra_info.append(
-                """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
-            )
-
-            self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))
-
-            if self.use_first_obs:
-                msg = "Here is the first observation."
-                screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                if self.tag_screenshot:
-                    msg += " A red dot on screenshots indicate the previous click action."
-                message = MessageBuilder.user().add_text(msg)
-                message.add_image(image_to_png_base64_url(obs[screenshot_key]))
-                self.messages.append(message)
+            self.initalize_messages(obs)
         else:
-            if obs["last_action_error"] == "":
+            if obs["last_action_error"] == "":  # no error in the last action
                 screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                tool_message = MessageBuilder.tool().add_image(
+                tool_message = self.msg_builder.tool().add_image(
                     image_to_png_base64_url(obs[screenshot_key])
                 )
+                tool_message.update_last_raw_response(self.last_response)
                 tool_message.add_tool_id(self.previous_call_id)
                 self.messages.append(tool_message)
             else:
-                tool_message = MessageBuilder.tool().add_text(
+                tool_message = self.msg_builder.tool().add_text(
                     f"Function call failed: {obs['last_action_error']}"
                 )
                 tool_message.add_tool_id(self.previous_call_id)
+                tool_message.update_last_raw_response(self.last_response)
                 self.messages.append(tool_message)
 
         response: ResponseLLMOutput = self.llm(messages=self.messages)
 
         action = response.action
         think = response.think
+        self.last_response = response
         self.previous_call_id = response.last_computer_call_id
-        self.messages.append(response.assistant_message)
+        self.messages.append(response.assistant_message)  # the assistant's tool call
 
         return (
             action,
@@ -210,6 +143,37 @@ def get_action(self, obs: Any) -> float:
             ),
         )
 
+    def initalize_messages(self, obs: Any) -> None:
+        system_message = self.msg_builder.system().add_text(
+            "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
+        )
+        self.messages.append(system_message)
+
+        goal_message = self.msg_builder.user()
+        for content in obs["goal_object"]:
+            if content["type"] == "text":
+                goal_message.add_text(content["text"])
+            elif content["type"] == "image_url":
+                goal_message.add_image(content["image_url"])
+        self.messages.append(goal_message)
+
+        extra_info = []
+
+        extra_info.append(
+            """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for multiple selection in lists.\n"""
+        )
+
+        self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info)))
+
+        if self.use_first_obs:
+            msg = "Here is the first observation."
+            screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
+            if self.tag_screenshot:
+                msg += " A red dot on screenshots indicates the previous click action."
+            message = self.msg_builder.user().add_text(msg)
+            message.add_image(image_to_png_base64_url(obs[screenshot_key]))
+            self.messages.append(message)
+
 
 OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
     model_name="gpt-4.1",
@@ -220,6 +184,14 @@
     vision_support=True,
 )
 
+OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
+    model_name="gpt-4o-2024-08-06",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
 
 CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
     model_name="claude-3-7-sonnet-20250219",
@@ -231,6 +203,103 @@
 )
 
 
+
+
+# def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
+#     default_model_args = {
+#         "max_total_tokens": 200_000,
+#         "max_input_tokens": 180_000,
+#         "max_new_tokens": 2_000,
+#         "temperature": 0.1,
+#         "vision_support": True,
+#     }
+#     merged_args = {**default_model_args, **open_router_args}
+
+#     return OpenRouterModelArgs(model_name=model_name, **merged_args)
+
+
+# def get_openrouter_tool_use_agent(
+#     model_name: str,
+#     model_args: dict = {},
+#     use_first_obs=True,
+#     tag_screenshot=True,
+#     use_raw_page_output=True,
+# ) -> ToolUseAgentArgs:
+#     # TODO: check that OpenRouter endpoint-specific args are working
+#     if not supports_tool_calling(model_name):
+#         raise ValueError(f"Model {model_name} does not support tool calling.")
+
+#     model_args = get_openrouter_model(model_name, **model_args)
+
+#     return ToolUseAgentArgs(
+#         model_args=model_args,
+#         use_first_obs=use_first_obs,
+#         tag_screenshot=tag_screenshot,
+#         use_raw_page_output=use_raw_page_output,
+#     )
+
+
+# OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")
+
+
 AGENT_CONFIG = ToolUseAgentArgs(
     model_args=CLAUDE_MODEL_CONFIG,
 )
+
+# MT_TOOL_USE_AGENT = ToolUseAgentArgs(
+#     model_args=OPENROUTER_MODEL,
+# )
+CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(
+        model_name="gpt-4o-2024-11-20",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=2_000,
+        temperature=0.7,
+        vision_support=True,
+    ),
+)
+
+
+OAI_CHAT_TOOl_AGENT = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
+)
+
+
+PROVIDER_FACTORY_MAP = {
+    "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
+    "openrouter": OpenRouterModelArgs,
+    "vllm": VLLMModelArgs,
+    "antrophic": ClaudeResponseModelArgs,
+}
+
+
+def get_tool_use_agent(
+    api_provider: str,
+    model_args: "BaseModelArgs",
+    tool_use_agent_args: dict = None,
+    api_provider_spec=None,
+) -> ToolUseAgentArgs:
+
+    if api_provider == "openai":
+        assert (
+            api_provider_spec is not None
+        ), "Endpoint specification is required for the OpenAI provider. Choose between 'chatcompletion' and 'response'."
+
+    model_args_factory = (
+        PROVIDER_FACTORY_MAP[api_provider]
+        if api_provider_spec is None
+        else PROVIDER_FACTORY_MAP[api_provider][api_provider_spec]
+    )
+
+    # Create the agent args with model arguments from the factory
+    agent = ToolUseAgentArgs(
+        model_args=model_args_factory(**model_args), **(tool_use_agent_args or {})
+    )
+    return agent
+
+
+# We want to support three providers:
+# - Anthropic
+# - OpenAI
+# - vLLM (uses the OpenAI API)
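
For reviewers, a minimal usage sketch of the new `get_tool_use_agent` factory, not part of the commit. The import path follows the file location above; the model name and sampling values are illustrative assumptions rather than values taken from this diff, and the `use_first_obs` / `tag_screenshot` flags mirror the `ToolUseAgentArgs` kwargs already used in this file. Note that the Anthropic entry in `PROVIDER_FACTORY_MAP` is keyed as "antrophic", so callers must match that spelling.

# Sketch only (assumed call pattern, not from the diff): build ToolUseAgentArgs
# through the provider factory. For "openai", api_provider_spec selects between
# the Chat Completions ("chatcompletion") and Responses ("response") APIs.
from agentlab.agents.tool_use_agent.agent import get_tool_use_agent

agent_args = get_tool_use_agent(
    api_provider="openai",
    api_provider_spec="chatcompletion",
    model_args={
        "model_name": "gpt-4o-2024-08-06",  # illustrative model choice
        "max_new_tokens": 2_000,  # illustrative sampling limits
        "temperature": 0.1,
    },
    tool_use_agent_args={"use_first_obs": True, "tag_screenshot": True},
)

For other providers (`"openrouter"`, `"vllm"`, `"antrophic"`), `api_provider_spec` stays `None` and the provider maps directly to a single model-args class; from the returned args the usual `AgentArgs` flow should apply (e.g. `agent_args.make_agent()`).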
