
Commit ef6f648

claude
1 parent 6604dbc commit ef6f648

3 files changed: 158 additions & 66 deletions

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 124 additions & 59 deletions
@@ -5,17 +5,60 @@
 from typing import TYPE_CHECKING, Any
 
 import bgym
+import numpy as np
 from browsergym.core.observation import extract_screenshot
+from PIL import Image, ImageDraw
 
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
-from agentlab.llm.response_api import OpenAIResponseModelArgs
+from agentlab.llm.response_api import (
+    ClaudeResponseModelArgs,
+    MessageBuilder,
+    OpenAIResponseModelArgs,
+)
 from agentlab.llm.tracking import cost_tracker_decorator
 
 if TYPE_CHECKING:
     from openai.types.responses import Response
 
 
+def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
+    """
+    If action is a coordinate action, try to render it on the screenshot.
+
+    e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
+
+    Args:
+        screenshot: The screenshot to tag.
+        action: The action to tag the screenshot with.
+
+    Returns:
+        The tagged screenshot.
+
+    Raises:
+        ValueError: If the action parsing fails.
+    """
+    if action.startswith("mouse_click"):
+        try:
+            coords = action[action.index("(") + 1 : action.index(")")].split(",")
+            coords = [c.strip() for c in coords]
+            if len(coords) != 2:
+                raise ValueError(f"Invalid coordinate format: {coords}")
+            if coords[0].startswith("x="):
+                coords[0] = coords[0][2:]
+            if coords[1].startswith("y="):
+                coords[1] = coords[1][2:]
+            x, y = float(coords[0].strip()), float(coords[1].strip())
+            draw = ImageDraw.Draw(screenshot)
+            radius = 5
+            draw.ellipse(
+                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
+            )
+        except (ValueError, IndexError) as e:
+            logging.warning(f"Failed to parse action '{action}': {e}")
+    return screenshot
+
+
 @dataclass
 class ToolUseAgentArgs(AgentArgs):
     temperature: float = 0.1
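For review, a quick sketch of how the new helper behaves (standalone, assuming only Pillow; the blank image and the coordinates are invented for illustration):

    from PIL import Image

    from agentlab.agents.tool_use_agent.agent import tag_screenshot_with_action

    # Hypothetical 800x600 screenshot; the agent normally captures one from the page.
    screenshot = Image.new("RGB", (800, 600), "white")

    # Both bare and keyword coordinate forms are handled by the parser above.
    tagged = tag_screenshot_with_action(screenshot, "mouse_click(120, 130)")
    tagged = tag_screenshot_with_action(tagged, "mouse_click(x=200, y=250)")

    # Non-click actions pass through untouched; parse failures only log a warning.
    tagged = tag_screenshot_with_action(tagged, "noop()")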
@@ -48,14 +91,18 @@ def __init__(
         self,
         temperature: float,
         model_args: OpenAIResponseModelArgs,
+        use_first_obs: bool = True,
+        tag_screenshot: bool = True,
     ):
         self.temperature = temperature
         self.chat = model_args.make_model()
         self.model_args = model_args
+        self.use_first_obs = use_first_obs
+        self.tag_screenshot = tag_screenshot
 
         self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
 
-        self.tools = self.action_set.to_tool_description()
+        self.tools = self.action_set.to_tool_description(api="anthropic")
 
         # self.tools.append(
         #     {
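For context, Anthropic's tool-use API describes each tool with an input_schema (JSON Schema) rather than OpenAI's parameters key, so to_tool_description(api="anthropic") presumably emits entries shaped roughly like the sketch below. The exact fields bgym generates are not part of this diff; the coordinate schema here is an assumption:

    # Hypothetical Anthropic-format tool entry for a coordinate click action.
    mouse_click_tool = {
        "name": "mouse_click",
        "description": "Click at the given (x, y) coordinates on the page.",
        "input_schema": {
            "type": "object",
            "properties": {
                "x": {"type": "number"},
                "y": {"type": "number"},
            },
            "required": ["x", "y"],
        },
    }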
@@ -77,87 +124,94 @@ def __init__(
 
         self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
 
-        self.messages = []
+        self.messages: list[MessageBuilder] = []
 
     def obs_preprocessor(self, obs):
         page = obs.pop("page", None)
         if page is not None:
             obs["screenshot"] = extract_screenshot(page)
+            if self.tag_screenshot:
+                obs["screenshot"] = Image.fromarray(obs["screenshot"])
+                obs["screenshot"] = tag_screenshot_with_action(
+                    obs["screenshot"], obs["last_action"]
+                )
+                obs["screenshot"] = np.array(obs["screenshot"])
         else:
             raise ValueError("No page found in the observation.")
 
         return obs
 
     @cost_tracker_decorator
-    def get_action(self, obs: Any) -> tuple[str, dict]:
-
+    def get_action(self, obs: Any) -> float:
         if len(self.messages) == 0:
-            system_message = {
-                "role": "system",
-                "content": "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal.",
-            }
-            goal_object = [el for el in obs["goal_object"]]
-            for content in goal_object:
-                if content["type"] == "text":
-                    content["type"] = "input_text"
-                elif content["type"] == "image_url":
-                    content["type"] = "input_image"
-            goal_message = {"role": "user", "content": goal_object}
-            goal_message["content"].append(
-                {
-                    "type": "input_image",
-                    "image_url": image_to_png_base64_url(obs["screenshot"]),
-                }
+            system_message = MessageBuilder.system().add_text(
+                "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
             )
             self.messages.append(system_message)
+
+            goal_message = MessageBuilder.user()
+            for content in obs["goal_object"]:
+                if content["type"] == "text":
+                    goal_message.add_text(content["text"])
+                elif content["type"] == "image_url":
+                    goal_message.add_image(content["image_url"])
             self.messages.append(goal_message)
+
+            if self.use_first_obs:
+                message = MessageBuilder.user().add_text(
+                    "Here is the first observation. A red dot on screenshots indicate the previous click action:"
+                )
+                message.add_image(image_to_png_base64_url(obs["screenshot"]))
+                self.messages.append(message)
         else:
             if obs["last_action_error"] == "":
-                self.messages.append(
-                    {
-                        "type": "function_call_output",
-                        "call_id": self.previous_call_id,
-                        "output": "Function call executed, see next observation.",
-                    }
-                )
-                self.messages.append(
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "input_image",
-                                "image_url": image_to_png_base64_url(obs["screenshot"]),
-                            }
-                        ],
-                    }
+                tool_message = MessageBuilder.tool().add_image(
+                    image_to_png_base64_url(obs["screenshot"])
                 )
+                tool_message.add_tool_id(self.previous_call_id)
+                self.messages.append(tool_message)
             else:
-                self.messages.append(
-                    {
-                        "type": "function_call_output",
-                        "call_id": self.previous_call_id,
-                        "output": f"Function call failed: {obs['last_action_error']}",
-                    }
+                tool_message = MessageBuilder.tool().add_text(
+                    f"Function call failed: {obs['last_action_error']}"
                 )
+                tool_message.add_tool_id(self.previous_call_id)
+                self.messages.append(tool_message)
 
+        messages = []
+        for msg in self.messages:
+            if isinstance(msg, MessageBuilder):
+                messages += msg.to_anthropic()
+            else:
+                messages.append(msg)
         response: "Response" = self.llm(
-            messages=self.messages,
+            messages=messages,
             temperature=self.temperature,
         )
 
         action = "noop()"
         think = ""
-        for output in response.output:
-            if output.type == "function_call":
-                arguments = json.loads(output.arguments)
-                action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
-                self.previous_call_id = output.call_id
-                self.messages.append(output)
-                break
-            elif output.type == "reasoning":
-                if len(output.summary) > 0:
-                    think += output.summary[0].text + "\n"
-                self.messages.append(output)
+        # openai
+        # for output in response.output:
+        #     if output.type == "function_call":
+        #         arguments = json.loads(output.arguments)
+        #         action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
+        #         self.previous_call_id = output.call_id
+        #         self.messages.append(output)
+        #         break
+        #     elif output.type == "reasoning":
+        #         if len(output.summary) > 0:
+        #             think += output.summary[0].text + "\n"
+        #         self.messages.append(output)
+
+        # anthropic
+        for output in response.content:
+            if output.type == "text":
+                think += output.text
+            elif output.type == "tool_use":
+                action = f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
+                self.previous_call_id = output.id
+
+        self.messages.append({"role": "assistant", "content": response.content})
 
         return (
             action,
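To make the new Anthropic branch concrete, here is a small self-contained sketch of the same parsing logic with stand-in content blocks (the block shapes follow Anthropic's Messages API; the id, text, and coordinates are invented):

    from dataclasses import dataclass
    from typing import Any


    # Stand-ins for Anthropic's text and tool_use content blocks.
    @dataclass
    class TextBlock:
        text: str
        type: str = "text"


    @dataclass
    class ToolUseBlock:
        id: str
        name: str
        input: dict[str, Any]
        type: str = "tool_use"


    content = [
        TextBlock(text="I should click the submit button."),
        ToolUseBlock(id="toolu_01A", name="mouse_click", input={"x": 120, "y": 130}),
    ]

    think, action = "", "noop()"
    for output in content:
        if output.type == "text":
            think += output.text
        elif output.type == "tool_use":
            # Same formatting rule as above: quote string values, leave numbers bare.
            args = ", ".join(
                f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                for k, v in output.input.items()
            )
            action = f"{output.name}({args})"

    print(action)  # mouse_click(x=120, y=130)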
@@ -170,15 +224,26 @@ def get_action(self, obs: Any) -> tuple[str, dict]:
 
 
 MODEL_CONFIG = OpenAIResponseModelArgs(
-    model_name="o4-mini-2025-04-16",
+    model_name="gpt-4o",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
+
+
+CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
+    model_name="claude-3-7-sonnet-20250219",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
-    max_new_tokens=100_000,
+    max_new_tokens=2_000,
     temperature=0.1,
     vision_support=True,
 )
 
+
 AGENT_CONFIG = ToolUseAgentArgs(
     temperature=0.1,
-    model_args=MODEL_CONFIG,
+    model_args=CLAUDE_MODEL_CONFIG,
 )

src/agentlab/analyze/agent_xray.py

Lines changed: 1 addition & 1 deletion
@@ -560,7 +560,7 @@ def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
             draw = ImageDraw.Draw(screenshot)
             radius = 5
             draw.ellipse(
-                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
+                (x - radius, y - radius, x + radius, y + radius), fill="blue", outline="blue"
             )
         except (ValueError, IndexError) as e:
             warning(f"Failed to parse action '{action}': {e}")
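The color switch here presumably keeps the X-ray viewer's overlay (blue) distinguishable from the red dot that agent.py now draws on the agent's own screenshots.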

src/agentlab/llm/response_api.py

Lines changed: 33 additions & 6 deletions
@@ -72,17 +72,29 @@ def to_openai(self) -> List[Message]:
 
     def to_anthropic(self) -> List[Message]:
         content = []
+
+        if self.role == "system":
+            logging.warning(
+                "In the Anthropic API, system messages should be passed as a direct input to the client."
+            )
+            return []
+
         for item in self.content:
             if "text" in item:
                 content.append({"type": "text", "text": item["text"]})
             elif "image" in item:
+                img_str: str = item["image"]
+                # make sure to get rid of the image type for anthropic
+                # e.g. "data:image/png;base64"
+                if img_str.startswith("data:image/png;base64,"):
+                    img_str = img_str[len("data:image/png;base64,") :]
                 content.append(
                     {
                         "type": "image",
                         "source": {
                             "type": "base64",  # currently only base64 is supported
                             "media_type": "image/png",  # currently only png is supported
-                            "data": item["image"],
+                            "data": img_str,
                         },
                     }
                 )
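A minimal standalone sketch of the new prefix stripping (the base64 payload is fake):

    data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="

    prefix = "data:image/png;base64,"
    img_str = data_url[len(prefix):] if data_url.startswith(prefix) else data_url

    # Anthropic's image source expects only the raw base64 payload:
    source = {"type": "base64", "media_type": "image/png", "data": img_str}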
@@ -91,11 +103,26 @@ def to_anthropic(self) -> List[Message]:
 
         if self.role == "tool":
             assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
             res[0]["role"] = "user"
-            res[0]["content"] = {
-                "type": "tool_result",
-                "tool_use_id": self.tool_call_id,
-                "content": res[0]["content"],
-            }
+            res[0]["content"] = [
+                {
+                    "type": "tool_result",
+                    "tool_use_id": self.tool_call_id,
+                    "content": res[0]["content"],
+                }
+            ]
+        return res
+
+    def to_markdown(self) -> str:
+        content = []
+        for item in self.content:
+            if "text" in item:
+                content.append(item["text"])
+            elif "image" in item:
+                content.append(f"![image]({item['image']})")
+        res = f"{self.role}: " + "\n".join(content)
+        if self.role == "tool":
+            assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
+            res += f"\nTool call ID: {self.tool_call_id}"
         return res
 
 
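And a sketch of what the new to_markdown renders for a tool message, mirroring the MessageBuilder calls used in agent.py (the tool ID is invented):

    msg = MessageBuilder.tool().add_text("Function call failed: timeout")
    msg.add_tool_id("toolu_01A")

    print(msg.to_markdown())
    # tool: Function call failed: timeout
    # Tool call ID: toolu_01A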
