Skip to content

Commit b8e5c3a

Browse files
committed
simpler tool call object
1 parent cc23893 commit b8e5c3a

File tree

8 files changed

+34
-52
lines changed

8 files changed

+34
-52
lines changed

src/agentlab/actions.py

Lines changed: 8 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
import json
22
import logging
3-
from typing import Any, Callable, Literal
3+
from typing import Callable, Literal
4+
from uuid import uuid4
45

56
from bgym import AbstractActionSet
67
from langchain_core.utils.function_calling import convert_to_openai_tool
7-
from pydantic import BaseModel
8+
from pydantic import BaseModel, Field
89

910
from agentlab.llm.llm_utils import parse_html_tags_raise
1011

@@ -26,22 +27,11 @@ class FunctionSpec(BaseModel):
2627
parameters: dict
2728

2829

29-
class FunctionCall(BaseModel):
30-
"""
31-
A class representing a function call.
32-
33-
Attributes:
34-
name (str): The name of the function being called.
35-
arguments (Any): The arguments to be passed to the function.
36-
"""
3730

31+
class ToolCall(BaseModel):
32+
id: str = Field(default_factory=lambda: uuid4().hex)
3833
name: str
39-
arguments: Any
40-
41-
42-
class ToolCallAction(BaseModel):
43-
id: str = ""
44-
function: FunctionCall
34+
arguments: dict = Field(default_factory=dict)
4535

4636
def llm_view(self, **kwargs) -> str:
4737
return self.model_dump_json(indent=2)
@@ -111,7 +101,7 @@ def example_action(self, abstract: bool) -> str:
111101
}"""
112102

113103
@classmethod
114-
def parse_action(cls, llm_output: str) -> ToolCallAction:
104+
def parse_action(cls, llm_output: str) -> ToolCall:
115105
logger.info(f"Parsing action: {llm_output}")
116106
if "<action>" in llm_output:
117107
content_dict, valid, retry_message = parse_html_tags_raise(llm_output, keys=["action"])
@@ -124,7 +114,7 @@ def parse_action(cls, llm_output: str) -> ToolCallAction:
124114
action_dict = json.loads(action_str)
125115
except json.JSONDecodeError:
126116
raise ValueError(f"Failed to parse action: {action_str}")
127-
return ToolCallAction(function=FunctionCall(name=action_dict["name"], arguments=action_dict["arguments"]))
117+
return ToolCall(name=action_dict["name"], arguments=action_dict["arguments"])
128118

129119
def to_python_code(self, action) -> str:
130120
return action

src/agentlab/agents/react_toolcall_agent.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from PIL import Image
1111
from termcolor import colored
1212

13-
from agentlab.actions import FunctionCall, ToolCallAction, ToolsActionSet, ToolSpec
13+
from agentlab.actions import ToolCall, ToolsActionSet, ToolSpec
1414
from agentlab.agents.agent_args import AgentArgs
1515
from agentlab.llm.chat_api import BaseModelArgs
1616
from agentlab.llm.llm_utils import image_to_png_base64_url
@@ -112,15 +112,13 @@ def obs_to_messages(self, obs: dict) -> list[dict]:
112112
)
113113
return messages
114114

115-
def get_action(self, obs: dict) -> tuple[ToolCallAction, dict]:
115+
def get_action(self, obs: dict) -> tuple[ToolCall, dict]:
116116
actions_count = len(
117117
[msg for msg in self.history if isinstance(msg, Message) and msg.tool_calls]
118118
)
119119
if actions_count >= self.config.max_actions:
120120
logger.warning("Max actions reached, stopping agent.")
121-
stop_action = ToolCallAction(
122-
id="stop", function=FunctionCall(name="final_step", arguments={})
123-
)
121+
stop_action = ToolCall(name="final_step")
124122
return stop_action, {}
125123
self.history += self.obs_to_messages(self.obs_preprocessor(obs))
126124
tools = [tool.model_dump() for tool in self.action_set.actions]
@@ -155,15 +153,14 @@ def thoughts_from_message(self, message) -> str:
155153
thoughts.append(message.content)
156154
return "\n\n".join(thoughts)
157155

158-
def action_from_message(self, message) -> ToolCallAction:
156+
def action_from_message(self, message) -> ToolCall:
159157
if message.tool_calls:
160158
if len(message.tool_calls) > 1:
161159
logger.warning("Multiple tool calls found in LLM response, using the first one.")
162160
tool_call = message.tool_calls[0]
161+
name = tool_call.function.name
163162
args = json.loads(tool_call.function.arguments)
164-
action = ToolCallAction(
165-
id=tool_call.id, function=FunctionCall(name=tool_call.function.name, arguments=args)
166-
)
163+
action = ToolCall(id=tool_call.id, name=name, arguments=args)
167164
self.last_tool_call_id = action.id
168165
logger.info(f"Parsed tool call action: {action}")
169166
else:

src/agentlab/backends/browser/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
from agentlab.actions import FunctionCall, ToolCallAction, ToolSpec
21
from agentlab.backends.browser.base import BrowserBackend
32
from agentlab.backends.browser.env import BrowserEnv, BrowserEnvArgs
43
from agentlab.backends.browser.mcp import MCPBrowserBackend, MCPClient
@@ -7,9 +6,6 @@
76

87
__all__ = [
98
"BrowserBackend",
10-
"FunctionCall",
11-
"ToolCallAction",
12-
"ToolSpec",
139
"BrowserEnv",
1410
"BrowserEnvArgs",
1511
"MCPBrowserBackend",

src/agentlab/backends/browser/base.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from PIL import Image
55
from pydantic import BaseModel
66

7-
from agentlab.actions import ToolCallAction, ToolSpec
7+
from agentlab.actions import ToolCall, ToolSpec
88

99
logger = logging.getLogger(__name__)
1010

@@ -35,7 +35,7 @@ def page_axtree(self) -> str:
3535
pass
3636

3737
@abstractmethod
38-
def step(self, action: ToolCallAction) -> str:
38+
def step(self, action: ToolCall) -> dict:
3939
pass
4040

4141
@abstractmethod

src/agentlab/backends/browser/env.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6-
from agentlab.actions import ToolCallAction, ToolsActionSet, ToolSpec
6+
from agentlab.actions import ToolCall, ToolsActionSet, ToolSpec
77
from agentlab.backends.browser.base import BrowserBackend
88
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
99
from agentlab.benchmarks.web_task import AbstractWebTask
@@ -52,13 +52,13 @@ def reset(self, seed: int):
5252
obs = self.task.obs_postprocess(obs)
5353
return obs, {}
5454

55-
def step(self, action: ToolCallAction | str) -> tuple[dict, float, bool, bool, dict]:
55+
def step(self, action: ToolCall | str) -> tuple[dict, float, bool, bool, dict]:
5656
if isinstance(action, str):
5757
action = ToolsActionSet.parse_action(action)
5858
logger.info(f"BrowserEnv.step() called with action {action}")
5959

6060
action_exec_start = time.time()
61-
finished = action.function.name == "final_step"
61+
finished = action.name == "final_step"
6262
if finished:
6363
observation = {
6464
"goal_object": [{"type": "text", "text": self.goal}],
@@ -91,7 +91,7 @@ def step(self, action: ToolCallAction | str) -> tuple[dict, float, bool, bool, d
9191
logger.info(f"Action result in observation: {observation}")
9292
return observation, reward, finished, truncated, env_info
9393

94-
def _step(self, action: ToolCallAction) -> dict:
94+
def _step(self, action: ToolCall) -> dict:
9595
obs_dict = self.backend.step(action)
9696
if "goal_object" not in obs_dict:
9797
obs_dict["goal_object"] = [{"type": "text", "text": self.goal}]
@@ -101,7 +101,7 @@ def _step(self, action: ToolCallAction) -> dict:
101101
obs_dict["focused_element_bid"] = "none"
102102
return obs_dict
103103

104-
def validate_task(self, action: ToolCallAction, observation: dict) -> tuple[float, dict]:
104+
def validate_task(self, action: ToolCall, observation: dict) -> tuple[float, dict]:
105105
validate_js = self.task.get_step_validate_js()
106106
validate_result = self.backend.run_js(validate_js)
107107
reward, other = self.task.parse_validation_result(validate_result)

src/agentlab/backends/browser/mcp.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from mcp import Tool as MCPTool
1111
from mcp.types import CallToolResult, ImageContent, TextContent
1212

13-
from agentlab.actions import FunctionSpec, ToolCallAction, ToolSpec
13+
from agentlab.actions import FunctionSpec, ToolCall, ToolSpec
1414
from agentlab.backends.browser.base import BrowserBackend
1515

1616
logger = logging.getLogger(__name__)
@@ -150,13 +150,12 @@ def initialize(self) -> None:
150150
self._mcp = MCPClient(config_path=self.config_path)
151151
self._mcp.initialize()
152152

153-
def step(self, action: ToolCallAction) -> dict:
154-
contents = self.call_tool(action.function.name, action.function.arguments)
153+
def step(self, action: ToolCall) -> dict:
154+
contents = self.call_tool(action.name, action.arguments)
155155
text = "\n".join([c.text for c in contents if c.type == "text"])
156156
images = [c for c in contents if c.type == "image"]
157157
return {
158-
"pruned_html": text,
159-
"axtree_txt": text,
158+
"text": text,
160159
"screenshot": images[-1] if images else None,
161160
}
162161

src/agentlab/backends/browser/mcp_playwright.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from PIL import Image
66

7-
from agentlab.actions import ToolCallAction
7+
from agentlab.actions import ToolCall
88
from agentlab.backends.browser.mcp import MCPBrowserBackend
99

1010
logger = logging.getLogger(__name__)
@@ -27,8 +27,8 @@ def run_js(self, js: str):
2727
raise e
2828
return result_str
2929

30-
def step(self, action: ToolCallAction) -> dict:
31-
contents = self.call_tool(action.function.name, action.function.arguments)
30+
def step(self, action: ToolCall) -> dict:
31+
contents = self.call_tool(action.name, action.arguments)
3232
logger.info(f"Step result has {len(contents)} contents")
3333
tool_result = "\n".join(
3434
[c.text for c in contents if c.type == "text" and "# Ran Playwright code" not in c.text]

src/agentlab/backends/browser/playwright.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from PIL import Image
77
from playwright.async_api import Browser, Page, async_playwright
88

9-
from agentlab.actions import ToolCallAction, ToolSpec
9+
from agentlab.actions import ToolCall, ToolSpec
1010
from agentlab.backends.browser.base import BrowserBackend
1111

1212
logger = logging.getLogger(__name__)
@@ -106,13 +106,13 @@ def page_axtree(self):
106106
flat_axtree = flatten_axtree(axtree)
107107
return flat_axtree
108108

109-
def step(self, action: ToolCallAction):
110-
fn = self._actions[action.function.name]
109+
def step(self, action: ToolCall):
110+
fn = self._actions[action.name]
111111
try:
112-
action_result = self._loop.run_until_complete(fn(**action.function.arguments))
112+
action_result = self._loop.run_until_complete(fn(**action.arguments))
113113
except Exception as e:
114-
logger.error(f"Error executing action {action.function.name}: {e}")
115-
action_result = f"Error executing action {action.function.name}: {e}"
114+
action_result = f"Error executing action {action.name}: {e}"
115+
logger.error(action_result)
116116
html = self.page_html()
117117
screenshot = self.page_screenshot()
118118
axtree = self.page_axtree()

0 commit comments

Comments
 (0)