Skip to content

Commit 73ba428

Browse files
committed
tool agent embryo
1 parent f9d7b91 commit 73ba428

File tree

3 files changed

+351
-0
lines changed

3 files changed

+351
-0
lines changed

src/agentlab/agents/tool_use_agent/__init__.py

Whitespace-only changes.
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import json
2+
import logging
3+
from copy import deepcopy as copy
4+
from dataclasses import asdict, dataclass
5+
from typing import TYPE_CHECKING, Any
6+
7+
import bgym
8+
from browsergym.core.observation import extract_screenshot
9+
10+
from agentlab.agents.agent_args import AgentArgs
11+
from agentlab.llm.llm_utils import image_to_png_base64_url
12+
from agentlab.llm.response_api import OpenAIResponseModelArgs
13+
from agentlab.llm.tracking import cost_tracker_decorator
14+
15+
if TYPE_CHECKING:
16+
from openai.types.responses import Response
17+
18+
19+
@dataclass
20+
class ToolUseAgentArgs(AgentArgs):
21+
temperature: float = 0.1
22+
model_args: OpenAIResponseModelArgs = None
23+
24+
def __post_init__(self):
25+
try:
26+
self.agent_name = f"ToolUse-{self.model_args.model_name}".replace("/", "_")
27+
except AttributeError:
28+
pass
29+
30+
def make_agent(self) -> bgym.Agent:
31+
return ToolUseAgent(
32+
temperature=self.temperature,
33+
model_args=self.model_args,
34+
)
35+
36+
def set_reproducibility_mode(self):
37+
self.temperature = 0
38+
39+
def prepare(self):
40+
return self.model_args.prepare_server()
41+
42+
def close(self):
43+
return self.model_args.close_server()
44+
45+
46+
class ToolUseAgent(bgym.Agent):
47+
def __init__(
48+
self,
49+
temperature: float,
50+
model_args: OpenAIResponseModelArgs,
51+
):
52+
self.temperature = temperature
53+
self.chat = model_args.make_model()
54+
self.model_args = model_args
55+
56+
self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
57+
58+
self.tools = self.action_set.to_tool_description()
59+
60+
# self.tools.append(
61+
# {
62+
# "type": "function",
63+
# "name": "chain_of_thought",
64+
# "description": "A tool that allows the agent to think step by step. Every other action must ALWAYS be preceeded by a call to this tool.",
65+
# "parameters": {
66+
# "type": "object",
67+
# "properties": {
68+
# "thoughts": {
69+
# "type": "string",
70+
# "description": "The agent's reasoning process.",
71+
# },
72+
# },
73+
# "required": ["thoughts"],
74+
# },
75+
# }
76+
# )
77+
78+
self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
79+
80+
self.messages = []
81+
82+
def obs_preprocessor(self, obs):
83+
page = obs.pop("page", None)
84+
if page is not None:
85+
obs["screenshot"] = extract_screenshot(page)
86+
else:
87+
raise ValueError("No page found in the observation.")
88+
89+
return obs
90+
91+
@cost_tracker_decorator
92+
def get_action(self, obs: Any) -> tuple[str, dict]:
93+
94+
if len(self.messages) == 0:
95+
system_message = {
96+
"role": "system",
97+
"content": "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal.",
98+
}
99+
goal_object = [el for el in obs["goal_object"]]
100+
for content in goal_object:
101+
if content["type"] == "text":
102+
content["type"] = "input_text"
103+
elif content["type"] == "image_url":
104+
content["type"] = "input_image"
105+
goal_message = {"role": "user", "content": goal_object}
106+
goal_message["content"].append(
107+
{
108+
"type": "input_image",
109+
"image_url": image_to_png_base64_url(obs["screenshot"]),
110+
}
111+
)
112+
self.messages.append(system_message)
113+
self.messages.append(goal_message)
114+
else:
115+
if obs["last_action_error"] == "":
116+
self.messages.append(
117+
{
118+
"type": "function_call_output",
119+
"call_id": self.previous_call_id,
120+
"output": "Function call executed, see next observation.",
121+
}
122+
)
123+
self.messages.append(
124+
{
125+
"role": "user",
126+
"content": [
127+
{
128+
"type": "input_image",
129+
"image_url": image_to_png_base64_url(obs["screenshot"]),
130+
}
131+
],
132+
}
133+
)
134+
else:
135+
self.messages.append(
136+
{
137+
"type": "function_call_output",
138+
"call_id": self.previous_call_id,
139+
"output": f"Function call failed: {obs['last_action_error']}",
140+
}
141+
)
142+
143+
response: "Response" = self.llm(
144+
messages=self.messages,
145+
temperature=self.temperature,
146+
)
147+
148+
action = "noop()"
149+
think = ""
150+
for output in response.output:
151+
if output.type == "function_call":
152+
arguments = json.loads(output.arguments)
153+
action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
154+
self.previous_call_id = output.call_id
155+
self.messages.append(output)
156+
break
157+
elif output.type == "reasoning":
158+
if len(output.summary) > 0:
159+
think += output.summary[0].text + "\n"
160+
self.messages.append(output)
161+
162+
return (
163+
action,
164+
bgym.AgentInfo(
165+
think=think,
166+
chat_messages=[],
167+
stats={},
168+
),
169+
)
170+
171+
172+
MODEL_CONFIG = OpenAIResponseModelArgs(
173+
model_name="o4-mini-2025-04-16",
174+
max_total_tokens=200_000,
175+
max_input_tokens=200_000,
176+
max_new_tokens=100_000,
177+
temperature=0.1,
178+
vision_support=True,
179+
)
180+
181+
AGENT_CONFIG = ToolUseAgentArgs(
182+
temperature=0.1,
183+
model_args=MODEL_CONFIG,
184+
)

src/agentlab/llm/response_api.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import logging
2+
from dataclasses import dataclass
3+
4+
import openai
5+
from openai import OpenAI
6+
7+
from .base_api import AbstractChatModel, BaseModelArgs
8+
9+
10+
class ResponseModel(AbstractChatModel):
11+
def __init__(
12+
self,
13+
model_name,
14+
api_key=None,
15+
temperature=0.5,
16+
max_tokens=100,
17+
extra_kwargs=None,
18+
):
19+
self.model_name = model_name
20+
self.api_key = api_key
21+
self.temperature = temperature
22+
self.max_tokens = max_tokens
23+
self.extra_kwargs = extra_kwargs or {}
24+
self.client = OpenAI(api_key=api_key)
25+
26+
def __call__(self, content: dict, temperature: float = None) -> dict:
27+
temperature = temperature if temperature is not None else self.temperature
28+
try:
29+
response = self.client.responses.create(
30+
model=self.model_name,
31+
input=content,
32+
# temperature=temperature,
33+
# previous_response_id=content.get("previous_response_id", None),
34+
max_output_tokens=self.max_tokens,
35+
**self.extra_kwargs,
36+
tool_choice="required",
37+
reasoning={
38+
"effort": "low",
39+
"summary": "detailed",
40+
},
41+
)
42+
return response
43+
except openai.OpenAIError as e:
44+
logging.error(f"Failed to get a response from the API: {e}")
45+
raise e
46+
47+
48+
class OpenAIResponseModel(ResponseModel):
49+
def __init__(
50+
self, model_name, api_key=None, temperature=0.5, max_tokens=100, extra_kwargs=None
51+
):
52+
super().__init__(
53+
model_name=model_name,
54+
api_key=api_key,
55+
temperature=temperature,
56+
max_tokens=max_tokens,
57+
extra_kwargs=extra_kwargs,
58+
)
59+
60+
def __call__(self, messages: list[dict], temperature: float = None) -> dict:
61+
return super().__call__(messages, temperature)
62+
# outputs = response.output
63+
# last_computer_call_id = None
64+
# answer_type = "call"
65+
# reasoning = "No reasoning"
66+
# for output in outputs:
67+
# if output.type == "reasoning":
68+
# reasoning = output.summary[0].text
69+
# elif output.type == "computer_call":
70+
# action = output.action
71+
# last_computer_call_id = output.call_id
72+
# res = response_to_text(action)
73+
# elif output.type == "message":
74+
# res = "noop()"
75+
# answer_type = "message"
76+
# else:
77+
# logging.warning(f"Unrecognized output type: {output.type}")
78+
# continue
79+
# return {
80+
# "think": reasoning,
81+
# "action": res,
82+
# "last_computer_call_id": last_computer_call_id,
83+
# "last_response_id": response.id,
84+
# "outputs": outputs,
85+
# "answer_type": answer_type,
86+
# }
87+
88+
89+
def response_to_text(action):
90+
"""
91+
Given a computer action (e.g., click, double_click, scroll, etc.),
92+
convert it to a text description.
93+
"""
94+
action_type = action.type
95+
96+
try:
97+
match action_type:
98+
99+
case "click":
100+
x, y = action.x, action.y
101+
button = action.button
102+
print(f"Action: click at ({x}, {y}) with button '{button}'")
103+
# Not handling things like middle click, etc.
104+
if button != "left" and button != "right":
105+
button = "left"
106+
return f"mouse_click({x}, {y}, button='{button}')"
107+
108+
case "scroll":
109+
x, y = action.x, action.y
110+
scroll_x, scroll_y = action.scroll_x, action.scroll_y
111+
print(
112+
f"Action: scroll at ({x}, {y}) with offsets (scroll_x={scroll_x}, scroll_y={scroll_y})"
113+
)
114+
return f"mouse_move({x}, {y})\nscroll({scroll_x}, {scroll_y})"
115+
116+
case "keypress":
117+
keys = action.keys
118+
for k in keys:
119+
print(f"Action: keypress '{k}'")
120+
# A simple mapping for common keys; expand as needed.
121+
if k.lower() == "enter":
122+
return "keyboard_press('Enter')"
123+
elif k.lower() == "space":
124+
return "keyboard_press(' ')"
125+
else:
126+
return f"keyboard_press('{k}')"
127+
128+
case "type":
129+
text = action.text
130+
print(f"Action: type text: {text}")
131+
return f"keyboard_type('{text}')"
132+
133+
case "wait":
134+
print(f"Action: wait")
135+
return "noop()"
136+
137+
case "screenshot":
138+
# Nothing to do as screenshot is taken at each turn
139+
print(f"Action: screenshot")
140+
141+
# Handle other actions here
142+
143+
case "drag":
144+
x1, y1 = action.path[0].x, action.path[0].y
145+
x2, y2 = action.path[1].x, action.path[1].y
146+
print(f"Action: drag from ({x1}, {y1}) to ({x2}, {y2})")
147+
return f"mouse_drag_and_drop({x1}, {y1}, {x2}, {y2})"
148+
149+
case _:
150+
print(f"Unrecognized action: {action}")
151+
152+
except Exception as e:
153+
print(f"Error handling action {action}: {e}")
154+
155+
156+
@dataclass
157+
class OpenAIResponseModelArgs(BaseModelArgs):
158+
"""Serializable object for instantiating a generic chat model with an OpenAI
159+
model."""
160+
161+
def make_model(self, extra_kwargs=None):
162+
return OpenAIResponseModel(
163+
model_name=self.model_name,
164+
temperature=self.temperature,
165+
max_tokens=self.max_new_tokens,
166+
extra_kwargs=extra_kwargs,
167+
)

0 commit comments

Comments
 (0)