Skip to content

Commit 4117e0a

Browse files
committed
remove tapeagents dep from backends core, fixes
1 parent a827344 commit 4117e0a

File tree

4 files changed

+236
-32
lines changed

4 files changed

+236
-32
lines changed

src/agentlab/backends/browser/base.py

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,75 @@
11
import logging
2+
from typing import Any, Callable, Literal
23

3-
from mcp.types import ImageContent, TextContent
4+
from langchain_core.utils.function_calling import convert_to_openai_tool
45
from PIL import Image
56
from pydantic import BaseModel
6-
from tapeagents.mcp import MCPEnvironment
7-
from tapeagents.tool_calling import FunctionCall, ToolCallAction, ToolSpec
87

98
logger = logging.getLogger(__name__)
109

1110

11+
class FunctionCall(BaseModel):
12+
"""
13+
A class representing a function call.
14+
15+
Attributes:
16+
name (str): The name of the function being called.
17+
arguments (Any): The arguments to be passed to the function.
18+
"""
19+
20+
name: str
21+
arguments: Any
22+
23+
24+
class FunctionSpec(BaseModel):
25+
"""
26+
A class representing the specification of a function.
27+
28+
Attributes:
29+
name (str): The name of the function.
30+
description (str): A brief description of the function.
31+
parameters (dict): A dictionary containing the parameters of the function.
32+
"""
33+
34+
name: str
35+
description: str
36+
parameters: dict
37+
38+
39+
class ToolCallAction(BaseModel):
40+
id: str = ""
41+
function: FunctionCall
42+
43+
44+
class ToolSpec(BaseModel):
45+
"""
46+
ToolSpec is a model that represents a tool specification with a type and a function.
47+
48+
Attributes:
49+
type (Literal["function"]): The type of the tool, which is always "function".
50+
function (FunctionSpec): The specification of the function.
51+
"""
52+
53+
type: Literal["function"] = "function"
54+
function: FunctionSpec
55+
56+
def description(self) -> str:
57+
return f"{self.function.name} - {self.function.description}"
58+
59+
@classmethod
60+
def from_function(cls, function: Callable):
61+
"""
62+
Creates an instance of the class by validating the model from a given function.
63+
64+
Args:
65+
function (Callable): The function to be converted and validated.
66+
67+
Returns:
68+
(ToolSpec): An instance of the class with the validated model.
69+
"""
70+
return cls.model_validate(convert_to_openai_tool(function))
71+
72+
1273
class BrowserBackend(BaseModel):
1374
def initialize(self) -> None:
1475
raise NotImplementedError
@@ -33,32 +94,3 @@ def actions(self) -> tuple[ToolSpec]:
3394

3495
def close(self) -> None:
3596
raise NotImplementedError
36-
37-
38-
class MCPBrowserBackend(BrowserBackend):
39-
config_path: str
40-
_mcp = None
41-
42-
def initialize(self) -> None:
43-
self._mcp = MCPEnvironment(config_path=self.config_path)
44-
self._mcp.initialize()
45-
46-
def step(self, action: ToolCallAction) -> dict:
47-
contents = self._call_mcp(action)
48-
text = "\n".join([c.text for c in contents if c.type == "text"])
49-
return {"pruned_html": text, "axtree_txt": text}
50-
51-
def call_tool(self, tool_name: str, arguments: dict) -> list[TextContent | ImageContent]:
52-
return self._call_mcp(
53-
ToolCallAction(function=FunctionCall(name=tool_name, arguments=arguments))
54-
)
55-
56-
def _call_mcp(self, action: ToolCallAction) -> list[TextContent | ImageContent]:
57-
tool_result = self._mcp.step(action)
58-
return tool_result.content.content
59-
60-
def actions(self) -> tuple[ToolSpec]:
61-
return self._mcp.actions()
62-
63-
def close(self) -> None:
64-
self._mcp.close()
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import asyncio
2+
import json
3+
import logging
4+
import os
5+
from contextlib import AsyncExitStack
6+
from datetime import timedelta
7+
from typing import Any
8+
9+
from mcp import ClientSession, StdioServerParameters, stdio_client
10+
from mcp import Tool as MCPTool
11+
from mcp.types import CallToolResult, ImageContent, TextContent
12+
13+
from agentlab.backends.browser.base import BrowserBackend, FunctionSpec, ToolCallAction, ToolSpec
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class MCPClient:
19+
def __init__(self, config_path: str, read_timeout_seconds: int = 10) -> None:
20+
self.servers = self.load_config(config_path)
21+
self.sessions: dict[str, ClientSession] = {}
22+
self.tools: dict[str, MCPTool] = {}
23+
self.tool_to_server: dict[str, str] = {}
24+
self.read_timeout_seconds = read_timeout_seconds
25+
self.exit_stack = AsyncExitStack()
26+
self.loop = None
27+
28+
def initialize(self):
29+
try:
30+
self.loop = asyncio.get_event_loop()
31+
except RuntimeError:
32+
self.loop = asyncio.new_event_loop()
33+
asyncio.set_event_loop(self.loop)
34+
self.loop.run_until_complete(self.start_servers())
35+
36+
async def ainitialize(self) -> None:
37+
await self.start_servers()
38+
39+
async def start_servers(self):
40+
for server_name, server_params in self.servers.items():
41+
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
42+
stdio, write = stdio_transport
43+
session = await self.exit_stack.enter_async_context(
44+
ClientSession(
45+
stdio, write, read_timeout_seconds=timedelta(seconds=self.read_timeout_seconds)
46+
)
47+
)
48+
await session.initialize()
49+
self.sessions[server_name] = session
50+
response = await session.list_tools()
51+
for tool in response.tools:
52+
if tool.name in self.tools:
53+
raise Exception(
54+
f"Tools conflict! Tool {tool.name} already provided by server '{self.tool_to_server[tool.name]}'"
55+
)
56+
self.tools[tool.name] = tool
57+
self.tool_to_server[tool.name] = server_name
58+
logger.info(
59+
f"Connected to MCP server '{server_name}' with tools: {[tool.name for tool in response.tools]}"
60+
)
61+
logger.info(f"Started {len(self.servers)} MCP servers")
62+
63+
def load_config(self, config_path) -> dict[str, StdioServerParameters]:
64+
assert os.path.exists(config_path), f"Config path {config_path} does not exist"
65+
self.config_path = config_path
66+
67+
try:
68+
with open(config_path, "r") as f:
69+
self.config = json.load(f)
70+
except json.JSONDecodeError as e:
71+
raise ValueError(f"Failed to parse {config_path}, invalid json: {e}")
72+
try:
73+
server_configs: dict[str, dict] = self.config["mcpServers"]
74+
assert isinstance(server_configs, dict), "mcpServers must be a dict"
75+
assert len(server_configs) > 0, "mcpServers dict is empty"
76+
except Exception as e:
77+
raise ValueError(f"Failed to get MCP server configs from {config_path}: {e}")
78+
79+
servers: dict[str, StdioServerParameters] = {}
80+
for server_name, server_config_dict in server_configs.items():
81+
try:
82+
server_config_dict = self.prepare_env_vars(server_config_dict)
83+
server_params = StdioServerParameters.model_validate(server_config_dict)
84+
except Exception as e:
85+
raise ValueError(f"Failed to parse server config {server_config_dict}: {e}")
86+
servers[server_name] = server_params
87+
logger.info(f"Loaded {len(servers)} MCP server configs from {config_path}")
88+
return servers
89+
90+
def prepare_env_vars(self, server_config_dict: dict) -> dict:
91+
if server_env := server_config_dict.get("env"):
92+
for env_var, env_value in server_env.items():
93+
if (
94+
env_var in os.environ and not env_value
95+
): # reuse existing env var value if not set in config
96+
logger.info(f"Set mcp server env var {env_var} from current environment")
97+
server_config_dict["env"][env_var] = os.environ[env_var]
98+
return server_config_dict
99+
100+
def call_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
101+
result = self.loop.run_until_complete(self.acall_tool(tool_name, tool_args))
102+
return result
103+
104+
async def acall_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
105+
server_name = self.check_tool_exists(tool_name)
106+
result = await self._call_tool(server_name, tool_name, tool_args)
107+
return result
108+
109+
async def _call_tool(
110+
self, server_name: str, tool_name: str, tool_args: dict[str, Any]
111+
) -> CallToolResult:
112+
try:
113+
session = self.sessions[server_name]
114+
result = await session.call_tool(tool_name, tool_args)
115+
except Exception as e:
116+
logger.exception(f"Error calling tool {tool_name}: {e}")
117+
raise e
118+
return result
119+
120+
def check_tool_exists(self, tool_name):
121+
try:
122+
server_name = self.tool_to_server[tool_name]
123+
except KeyError:
124+
raise Exception(f"Tool {tool_name} not found in any of the MCP servers")
125+
return server_name
126+
127+
def actions(self) -> tuple[ToolSpec]:
128+
return (
129+
ToolSpec(
130+
function=FunctionSpec(
131+
name=tool.name, description=tool.description or "", parameters=tool.inputSchema
132+
)
133+
)
134+
for tool in self.tools.values()
135+
)
136+
137+
async def close(self) -> None:
138+
await self.exit_stack.aclose()
139+
140+
141+
class MCPBrowserBackend(BrowserBackend):
142+
config_path: str
143+
_mcp = None
144+
145+
def initialize(self) -> None:
146+
self._mcp = MCPClient(config_path=self.config_path)
147+
self._mcp.initialize()
148+
149+
def step(self, action: ToolCallAction) -> dict:
150+
contents = self.call_tool(action.function.name, action.function.arguments)
151+
text = "\n".join([c.text for c in contents if c.type == "text"])
152+
images = [c for c in contents if c.type == "image"]
153+
return {
154+
"pruned_html": text,
155+
"axtree_txt": text,
156+
"screenshot": images[-1] if images else None,
157+
}
158+
159+
def call_tool(self, tool_name: str, arguments: dict) -> list[TextContent | ImageContent]:
160+
tool_result = self._mcp.call_tool(tool_name, arguments)
161+
if tool_result.isError:
162+
return [TextContent(text=f"Error calling tool {tool_name}: {tool_result.error}")]
163+
return tool_result.content
164+
165+
def actions(self) -> tuple[ToolSpec]:
166+
return self._mcp.actions()
167+
168+
def close(self) -> None:
169+
self._mcp.close()

src/agentlab/benchmarks/miniwob/benchmark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class MiniWobBenchmark(AbstractBenchmark):
2323
high_level_action_set_args: ToolsActionSet = None
2424

2525
def model_post_init(self, __context: Any) -> None:
26+
self.name = f"miniwob_{self.backend.__class__.__name__.lower()}"
2627
self.env_args_list = []
2728
if self.dataset is None:
2829
self.dataset = get_miniwob_tasks()

src/agentlab/benchmarks/miniwob/task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class MiniWobTask(AbstractWebTask):
3131
]
3232

3333
def model_post_init(self, __context: Any):
34+
if self.base_url.endswith("/"):
35+
self.base_url = self.base_url[:-1]
3436
self.url = f"{self.base_url}/{self.subdomain}.html"
3537

3638
def get_setup_js(self) -> str:

0 commit comments

Comments
 (0)