Skip to content

Commit 0fc43cc

Browse files
committed
unified claude and openai response apis
1 parent 54ec412 commit 0fc43cc

File tree

3 files changed

+98
-71
lines changed

3 files changed

+98
-71
lines changed

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 33 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
ClaudeResponseModelArgs,
1616
MessageBuilder,
1717
OpenAIResponseModelArgs,
18+
ResponseLLMOutput,
1819
)
1920
from agentlab.llm.tracking import cost_tracker_decorator
2021

@@ -61,8 +62,9 @@ def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
6162

6263
@dataclass
6364
class ToolUseAgentArgs(AgentArgs):
64-
temperature: float = 0.1
6565
model_args: OpenAIResponseModelArgs = None
66+
use_first_obs: bool = True
67+
tag_screenshot: bool = True
6668

6769
def __post_init__(self):
6870
try:
@@ -72,13 +74,11 @@ def __post_init__(self):
7274

7375
def make_agent(self) -> bgym.Agent:
7476
return ToolUseAgent(
75-
temperature=self.temperature,
7677
model_args=self.model_args,
78+
use_first_obs=self.use_first_obs,
79+
tag_screenshot=self.tag_screenshot,
7780
)
7881

79-
def set_reproducibility_mode(self):
80-
self.temperature = 0
81-
8282
def prepare(self):
8383
return self.model_args.prepare_server()
8484

@@ -89,20 +89,18 @@ def close(self):
8989
class ToolUseAgent(bgym.Agent):
9090
def __init__(
9191
self,
92-
temperature: float,
9392
model_args: OpenAIResponseModelArgs,
9493
use_first_obs: bool = True,
9594
tag_screenshot: bool = True,
9695
):
97-
self.temperature = temperature
9896
self.chat = model_args.make_model()
9997
self.model_args = model_args
10098
self.use_first_obs = use_first_obs
10199
self.tag_screenshot = tag_screenshot
102100

103101
self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
104102

105-
self.tools = self.action_set.to_tool_description(api="anthropic")
103+
self.tools = self.action_set.to_tool_description(api=model_args.api)
106104

107105
# self.tools.append(
108106
# {
@@ -131,11 +129,9 @@ def obs_preprocessor(self, obs):
131129
if page is not None:
132130
obs["screenshot"] = extract_screenshot(page)
133131
if self.tag_screenshot:
134-
obs["screenshot"] = Image.fromarray(obs["screenshot"])
135-
obs["screenshot"] = tag_screenshot_with_action(
136-
obs["screenshot"], obs["last_action"]
137-
)
138-
obs["screenshot"] = np.array(obs["screenshot"])
132+
screenshot = Image.fromarray(obs["screenshot"])
133+
screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
134+
obs["screenshot_tag"] = np.array(screenshot)
139135
else:
140136
raise ValueError("No page found in the observation.")
141137

@@ -158,16 +154,25 @@ def get_action(self, obs: Any) -> float:
158154
self.messages.append(goal_message)
159155

160156
if self.use_first_obs:
161-
message = MessageBuilder.user().add_text(
162-
"Here is the first observation. A red dot on screenshots indicate the previous click action:"
163-
)
164-
message.add_image(image_to_png_base64_url(obs["screenshot"]))
157+
if self.tag_screenshot:
158+
message = MessageBuilder.user().add_text(
159+
"Here is the first observation. A red dot on screenshots indicate the previous click action:"
160+
)
161+
message.add_image(image_to_png_base64_url(obs["screenshot_tag"]))
162+
else:
163+
message = MessageBuilder.user().add_text("Here is the first observation:")
164+
message.add_image(image_to_png_base64_url(obs["screenshot"]))
165165
self.messages.append(message)
166166
else:
167167
if obs["last_action_error"] == "":
168-
tool_message = MessageBuilder.tool().add_image(
169-
image_to_png_base64_url(obs["screenshot"])
170-
)
168+
if self.tag_screenshot:
169+
tool_message = MessageBuilder.tool().add_image(
170+
image_to_png_base64_url(obs["screenshot_tag"])
171+
)
172+
else:
173+
tool_message = MessageBuilder.tool().add_image(
174+
image_to_png_base64_url(obs["screenshot"])
175+
)
171176
tool_message.add_tool_id(self.previous_call_id)
172177
self.messages.append(tool_message)
173178
else:
@@ -177,21 +182,12 @@ def get_action(self, obs: Any) -> float:
177182
tool_message.add_tool_id(self.previous_call_id)
178183
self.messages.append(tool_message)
179184

180-
messages = []
181-
for msg in self.messages:
182-
if isinstance(msg, MessageBuilder):
183-
messages += msg.to_anthropic()
184-
else:
185-
messages.append(msg)
186-
response: "Response" = self.llm(
187-
messages=messages,
188-
temperature=self.temperature,
189-
)
185+
response: ResponseLLMOutput = self.llm(messages=self.messages)
190186

191-
action = response["action"]
192-
think = response["think"]
193-
self.previous_call_id = response["last_computer_call_id"]
194-
self.messages.append(response["assistant_message"])
187+
action = response.action
188+
think = response.think
189+
self.previous_call_id = response.last_computer_call_id
190+
self.messages.append(response.assistant_message)
195191

196192
return (
197193
action,
@@ -203,8 +199,8 @@ def get_action(self, obs: Any) -> float:
203199
)
204200

205201

206-
MODEL_CONFIG = OpenAIResponseModelArgs(
207-
model_name="gpt-4o",
202+
OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
203+
model_name="gpt-4.1",
208204
max_total_tokens=200_000,
209205
max_input_tokens=200_000,
210206
max_new_tokens=2_000,
@@ -224,6 +220,5 @@ def get_action(self, obs: Any) -> float:
224220

225221

226222
AGENT_CONFIG = ToolUseAgentArgs(
227-
temperature=0.1,
228-
model_args=CLAUDE_MODEL_CONFIG,
223+
model_args=OPENAI_MODEL_CONFIG,
229224
)

src/agentlab/analyze/agent_xray.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -634,6 +634,9 @@ def dict_to_markdown(d: dict):
634634
635635
dict: type = dict[str, str | list[dict[...]]]
636636
"""
637+
if not isinstance(d, dict):
638+
warning(f"Expected dict, got {type(d)}")
639+
return repr(d)
637640
if not d:
638641
return "No Data"
639642
res = ""
@@ -661,7 +664,7 @@ def update_chat_messages():
661664

662665
if isinstance(chat_messages, list) and isinstance(chat_messages[0], MessageBuilder):
663666
chat_messages = [
664-
m.to_markdown() if not isinstance(m, dict) else dict_to_markdown(m)
667+
m.to_markdown() if isinstance(m, MessageBuilder) else dict_to_markdown(m)
665668
for m in chat_messages
666669
]
667670
return "\n\n".join(chat_messages)

src/agentlab/llm/response_api.py

Lines changed: 61 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,17 @@
1515
type Message = Dict[str, Union[str, List[ContentItem]]]
1616

1717

18+
@dataclass
19+
class ResponseLLMOutput:
20+
"""Serializable object for the output of a response LLM."""
21+
22+
raw_response: Any
23+
think: str
24+
action: str
25+
last_computer_call_id: str
26+
assistant_message: Any
27+
28+
1829
class MessageBuilder:
1930
def __init__(self, role: str):
2031
self.role = role
@@ -63,13 +74,17 @@ def to_openai(self) -> List[Message]:
6374
# tool messages can only take text with openai
6475
# we need to split the first content element if it's text and use it
6576
# then open a new (user) message with the rest
66-
res[0]["tool_call_id"] = self.tool_call_id
77+
# a function_call_output dict has keys "call_id", "type" and "output"
78+
res[0]["call_id"] = self.tool_call_id
79+
res[0]["type"] = "function_call_output"
80+
res[0].pop("role", None) # make sure to remove role
6781
text_content = (
6882
content.pop(0)["text"]
6983
if "text" in content[0]
7084
else "Tool call answer in next message"
7185
)
72-
res[0]["content"] = text_content
86+
res[0]["output"] = text_content
87+
res[0].pop("content", None) # make sure to remove content
7388
res.append({"role": "user", "content": content})
7489

7590
return res
@@ -116,6 +131,8 @@ def to_anthropic(self) -> List[Message]:
116131
]
117132
return res
118133

134+
def to_chat_completion(self) -> List[Message]: ...
135+
119136
def to_markdown(self) -> str:
120137
content = []
121138
for item in self.content:
@@ -159,12 +176,12 @@ def __call__(self, messages: list[dict | MessageBuilder]) -> dict:
159176
return self._parse_response(response)
160177

161178
@abstractmethod
162-
def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
179+
def _call_api(self, messages: list[dict | MessageBuilder]) -> Any:
163180
"""Make a call to the model API and return the raw response."""
164181
pass
165182

166183
@abstractmethod
167-
def _parse_response(self, response: dict) -> dict:
184+
def _parse_response(self, response: Any) -> ResponseLLMOutput:
168185
"""Parse the raw response from the model API and return a structured response."""
169186
pass
170187

@@ -187,11 +204,17 @@ def __init__(
187204
)
188205
self.client = OpenAI(api_key=api_key)
189206

190-
def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
207+
def _call_api(self, messages: list[Any | MessageBuilder]) -> dict:
208+
input = []
209+
for msg in messages:
210+
if isinstance(msg, MessageBuilder):
211+
input += msg.to_openai()
212+
else:
213+
input.append(msg)
191214
try:
192215
response = self.client.responses.create(
193216
model=self.model_name,
194-
input=messages,
217+
input=input,
195218
temperature=self.temperature,
196219
# previous_response_id=content.get("previous_response_id", None),
197220
max_output_tokens=self.max_tokens,
@@ -208,27 +231,25 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
208231
raise e
209232

210233
def _parse_response(self, response: dict) -> dict:
211-
result = {
212-
"raw_response": response,
213-
"think": "",
214-
"action": "noop()",
215-
"last_computer_call_id": None,
216-
"assistant_message": {
217-
"role": "assistant",
218-
"content": response.output,
219-
},
220-
}
234+
result = ResponseLLMOutput(
235+
raw_response=response,
236+
think="",
237+
action="noop()",
238+
last_computer_call_id=None,
239+
assistant_message=None,
240+
)
221241
for output in response.output:
222242
if output.type == "function_call":
223243
arguments = json.loads(output.arguments)
224-
result["action"] = (
244+
result.action = (
225245
f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
226246
)
227-
result["last_computer_call_id"] = output.call_id
247+
result.last_computer_call_id = output.call_id
248+
result.assistant_message = output
228249
break
229250
elif output.type == "reasoning":
230251
if len(output.summary) > 0:
231-
result["think"] += output.summary[0].text + "\n"
252+
result.think += output.summary[0].text + "\n"
232253
return result
233254

234255

@@ -251,10 +272,16 @@ def __init__(
251272
self.client = Anthropic(api_key=api_key)
252273

253274
def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
275+
input = []
276+
for msg in messages:
277+
if isinstance(msg, MessageBuilder):
278+
input += msg.to_anthropic()
279+
else:
280+
input.append(msg)
254281
try:
255282
response = self.client.messages.create(
256283
model=self.model_name,
257-
messages=messages,
284+
messages=input,
258285
temperature=self.temperature,
259286
max_tokens=self.max_tokens,
260287
**self.extra_kwargs,
@@ -265,24 +292,22 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
265292
raise e
266293

267294
def _parse_response(self, response: dict) -> dict:
268-
result = {
269-
"raw_response": response,
270-
"think": "",
271-
"action": "noop()",
272-
"last_computer_call_id": None,
273-
"assistant_message": {
295+
result = ResponseLLMOutput(
296+
raw_response=response,
297+
think="",
298+
action="noop()",
299+
last_computer_call_id=None,
300+
assistant_message={
274301
"role": "assistant",
275302
"content": response.content,
276303
},
277-
}
304+
)
278305
for output in response.content:
279306
if output.type == "tool_use":
280-
result["action"] = (
281-
f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
282-
)
283-
result["last_computer_call_id"] = output.id
307+
result.action = f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
308+
result.last_computer_call_id = output.id
284309
elif output.type == "text":
285-
result["think"] += output.text
310+
result.think += output.text
286311
return result
287312

288313

@@ -358,6 +383,8 @@ class OpenAIResponseModelArgs(BaseModelArgs):
358383
"""Serializable object for instantiating a generic chat model with an OpenAI
359384
model."""
360385

386+
api = "openai"
387+
361388
def make_model(self, extra_kwargs=None):
362389
return OpenAIResponseModel(
363390
model_name=self.model_name,
@@ -372,6 +399,8 @@ class ClaudeResponseModelArgs(BaseModelArgs):
372399
"""Serializable object for instantiating a generic chat model with an OpenAI
373400
model."""
374401

402+
api = "anthropic"
403+
375404
def make_model(self, extra_kwargs=None):
376405
return ClaudeResponseModel(
377406
model_name=self.model_name,

0 commit comments

Comments
 (0)