Skip to content

Commit 7a682a0

Browse files
committed
new react toolcall agent, inspired by tapeagents but independent
1 parent dfbc005 commit 7a682a0

File tree

3 files changed

+226
-4
lines changed

3 files changed

+226
-4
lines changed

experiments/run_miniwob.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88

99
from agentlab.agents.generic_agent.agent_configs import GPT5_MINI_FLAGS
1010
from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
11+
from agentlab.agents.react_toolcall_agent import AgentConfig, LLMArgs, ReactToolCallAgentArgs
1112
from agentlab.agents.tapeagent.agent import TapeAgentArgs, load_config
1213
from agentlab.backends.browser.mcp_playwright import MCPPlaywright
1314
from agentlab.backends.browser.playwright import AsyncPlaywright
1415
from agentlab.benchmarks.miniwob import MiniWobBenchmark
1516
from agentlab.experiments.study import make_study
17+
from agentlab.llm.chat_api import BaseModelArgs
1618
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
1719

1820
fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
@@ -32,7 +34,7 @@ def parse_args():
3234
)
3335
parser.add_argument(
3436
"--agent",
35-
choices=["tape", "generic"],
37+
choices=["tape", "generic", "react"],
3638
default="tape",
3739
help="Agent type to use (default: tape)",
3840
)
@@ -63,6 +65,11 @@ def parse_args():
6365
chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-mini-2025-08-07"],
6466
flags=GPT5_MINI_FLAGS,
6567
)
68+
elif args.agent == "react":
69+
agent_args = ReactToolCallAgentArgs(
70+
llm_args=LLMArgs(model_name="azure/gpt-5-mini", temperature=1.0, max_total_tokens=128000),
71+
config=AgentConfig(),
72+
)
6673
else:
6774
agent_args = TapeAgentArgs(agent_name=config.name, config=config)
6875

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import json
2+
import logging
3+
import pprint
4+
from dataclasses import dataclass
5+
from functools import partial
6+
from typing import Callable
7+
8+
from litellm import completion_with_retries
9+
from litellm.types.utils import ChatCompletionMessageToolCall, Message, ModelResponse
10+
from PIL import Image
11+
from termcolor import colored
12+
13+
from agentlab.actions import FunctionCall, ToolCallAction, ToolsActionSet, ToolSpec
14+
from agentlab.agents.agent_args import AgentArgs
15+
from agentlab.llm.chat_api import BaseModelArgs
16+
from agentlab.llm.llm_utils import image_to_png_base64_url
17+
18+
logger = logging.getLogger(__name__)
19+
20+
@dataclass
21+
class Observation:
22+
data: dict
23+
24+
def to_messages(self) -> list[dict]:
25+
messages = []
26+
tool_call_id = self.data.get("tool_call_id")
27+
if self.data.get("goal_object") and not tool_call_id: # its a first observation when there are no tool_call_id, so include goal
28+
goal=self.data["goal_object"][0]["text"]
29+
messages.append({
30+
"role": "user",
31+
"content": f"## Goal:\n{goal}"
32+
})
33+
text_obs = []
34+
if self.data.get("action_result"):
35+
result=self.data["action_result"]
36+
text_obs.append(f"Action Result:\n{result}")
37+
if self.data.get("pruned_html"):
38+
html=self.data["pruned_html"]
39+
text_obs.append(f"Pruned HTML:\n{html}")
40+
if self.data.get("axtree_txt"):
41+
axtree=self.data["axtree_txt"]
42+
text_obs.append(f"Accessibility Tree:\n{axtree}")
43+
if self.data.get("last_action_error"):
44+
error = self.data['last_action_error']
45+
text_obs.append(f"Action Error:\n{error}")
46+
if text_obs:
47+
if tool_call_id:
48+
message = {
49+
"role": "tool",
50+
"tool_call_id": tool_call_id,
51+
"content": "\n\n".join(text_obs),
52+
}
53+
else:
54+
message = {
55+
"role": "user",
56+
"content": "\n\n".join(text_obs),
57+
}
58+
messages.append(message)
59+
if self.data.get("screenshot"):
60+
if isinstance(self.data["screenshot"], Image.Image):
61+
image_content_url = image_to_png_base64_url(self.data["screenshot"])
62+
messages.append({
63+
"role": "user",
64+
"content": [{"type": "image_url", "image_url": {"url": image_content_url}}],
65+
})
66+
else:
67+
raise ValueError(f"Expected Image.Image, got {type(self.data['screenshot'])}")
68+
return messages
69+
70+
@dataclass
71+
class LLMOutput:
72+
message: Message
73+
def to_messages(self) -> list[Message]:
74+
return [self.message]
75+
76+
@dataclass
77+
class SystemMessage:
78+
message: str
79+
def to_messages(self) -> list[dict]:
80+
return [{"role": "system", "content": self.message}]
81+
82+
@dataclass
83+
class UserMessage:
84+
message: str
85+
def to_messages(self) -> list[dict]:
86+
return [{"role": "user", "content": self.message}]
87+
88+
Step = LLMOutput | Observation | SystemMessage | UserMessage
89+
90+
@dataclass
91+
class AgentConfig:
92+
use_html: bool = True
93+
use_axtree: bool = False
94+
use_screenshot: bool = True
95+
max_actions: int = 10
96+
max_retry: int = 4
97+
system_prompt: str = """
98+
You are an expert AI Agent trained to assist users with complex web tasks.
99+
Your role is to understand the goal, perform actions until the goal is accomplished and respond in a helpful and accurate manner.
100+
Keep your replies brief, concise, direct and on topic. Prioritize clarity and avoid over-elaboration.
101+
Do not express emotions or opinions.
102+
"""
103+
guidance: str = """
104+
Think along the following lines:
105+
1. Summarize the last observation and describe the visible changes in the state.
106+
2. Evaluate action success, explain impact on task and next steps.
107+
3. If you see any errors in the last observation, think about it. If there is no error, just move on.
108+
4. List next steps to move towards the goal and propose next immediate action.
109+
Then produce the function call that performs the proposed action. If the task is complete, produce the final step.
110+
"""
111+
112+
class LLMArgs(BaseModelArgs):
113+
reasoning_effort: str = "low"
114+
115+
def make_model(self) -> Callable:
116+
return partial(
117+
completion_with_retries,
118+
model=self.model_name,
119+
temperature=self.temperature,
120+
max_tokens=self.max_total_tokens,
121+
max_completion_tokens=self.max_new_tokens,
122+
reasoning_effort=self.reasoning_effort,
123+
)
124+
125+
class ReactToolCallAgent:
126+
def __init__(self, action_set: ToolsActionSet, llm: Callable, config: AgentConfig):
127+
self.action_set = action_set
128+
self.history: list[Step] = [SystemMessage(message=config.system_prompt)]
129+
self.llm = llm
130+
self.config = config
131+
self.last_tool_call_id: str = ""
132+
133+
def obs_preprocessor(self, obs: dict) -> dict:
134+
if not self.config.use_html:
135+
obs.pop("pruned_html", None)
136+
if not self.config.use_axtree:
137+
obs.pop("axtree_txt", None)
138+
if not self.config.use_screenshot:
139+
obs.pop("screenshot", None)
140+
if self.last_tool_call_id:
141+
obs["tool_call_id"] = self.last_tool_call_id
142+
return obs
143+
144+
def get_action(self, obs: dict) -> tuple[ToolCallAction, dict]:
145+
prev_actions = [step for step in self.history if isinstance(step, LLMOutput)]
146+
if len(prev_actions) >= self.config.max_actions:
147+
logger.warning("Max actions reached, stopping agent.")
148+
stop_action = ToolCallAction(id="stop", function=FunctionCall(name="final_step", arguments={}))
149+
return stop_action, {}
150+
self.history.append(Observation(data=obs))
151+
steps = self.history + [UserMessage(message=self.config.guidance)]
152+
messages = [m for step in steps for m in step.to_messages()]
153+
tools = [tool.model_dump() for tool in self.action_set.actions]
154+
try:
155+
logger.info(colored(f"Prompt:\n{pprint.pformat(messages, width=120)}", "blue"))
156+
response: ModelResponse = self.llm(
157+
tools=tools,
158+
messages=messages,
159+
num_retries=self.config.max_retry,
160+
)
161+
message = response.choices[0].message # type: ignore
162+
except Exception as e:
163+
logger.exception(f"Error getting LLM response: {e}. Prompt: {messages}")
164+
raise e
165+
logger.info(colored(f"LLM response:\n{pprint.pformat(message, width=120)}", "green"))
166+
self.history.append(LLMOutput(message=message))
167+
thoughts = self.thoughts_from_message(message)
168+
action = self.action_from_message(message)
169+
170+
return action, {"think": thoughts}
171+
172+
def thoughts_from_message(self, message) -> str:
173+
thoughts = []
174+
if reasoning := message.get("reasoning_content"):
175+
logger.info(colored(f"LLM reasoning:\n{reasoning}", "yellow"))
176+
thoughts.append(reasoning)
177+
if blocks := message.get("thinking_blocks"):
178+
for block in blocks:
179+
if thinking := getattr(block, "content", None) or getattr(block, "thinking", None):
180+
logger.info(colored(f"LLM thinking block:\n{thinking}", "yellow"))
181+
thoughts.append(thinking)
182+
if message.content:
183+
logger.info(colored(f"LLM output:\n{message.content}", "cyan"))
184+
thoughts.append(message.content)
185+
return "\n\n".join(thoughts)
186+
187+
def action_from_message(self, message) -> ToolCallAction:
188+
if message.tool_calls:
189+
if len(message.tool_calls) > 1:
190+
logger.warning("Multiple tool calls found in LLM response, using the first one.")
191+
tool_call: ChatCompletionMessageToolCall = message.tool_calls[0]
192+
assert isinstance(tool_call.function.name, str)
193+
try:
194+
args = json.loads(tool_call.function.arguments)
195+
action = ToolCallAction(
196+
id=tool_call.id,
197+
function=FunctionCall(name=tool_call.function.name, arguments=args)
198+
)
199+
except json.JSONDecodeError as e:
200+
logger.exception(f"Error in json parsing of tool call arguments, {e}: {tool_call.function.arguments}")
201+
raise e
202+
203+
self.last_tool_call_id = action.id
204+
else:
205+
raise ValueError(f"No tool call found in LLM response: {message}")
206+
return action
207+
208+
209+
@dataclass
210+
class ReactToolCallAgentArgs(AgentArgs):
211+
llm_args: LLMArgs = None # type: ignore
212+
config: AgentConfig = None # type: ignore
213+
214+
def make_agent(self, actions: list[ToolSpec]) -> ReactToolCallAgent:
215+
llm = self.llm_args.make_model()
216+
action_set = ToolsActionSet(actions=actions)
217+
return ReactToolCallAgent(action_set=action_set, llm=llm, config=self.config)
218+

src/agentlab/backends/browser/env.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def reset(self, seed: int):
5050
"focused_element_bid": "none",
5151
}
5252
obs = self.task.obs_postprocess(obs)
53-
logger.info(f"Initial obs: {obs}")
5453
return obs, {}
5554

5655
def step(self, action: ToolCallAction | str) -> tuple[dict, float, bool, bool, dict]:
@@ -74,8 +73,6 @@ def step(self, action: ToolCallAction | str) -> tuple[dict, float, bool, bool, d
7473

7574
action_exec_stop = time.time()
7675
self._turns += 1
77-
logger.info(f"Obs: {observation}")
78-
7976
truncated = self._turns >= self.max_turns
8077

8178
if self.task.validate_per_step or finished or truncated:

0 commit comments

Comments
 (0)