Skip to content

Commit 2e4032e

Browse files
DouweMKludex
andauthored
Set MCPServer id and tool_prefix in load_mcp_servers (#3052)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 3b8ff2c commit 2e4032e

File tree

4 files changed

+125
-96
lines changed

4 files changed

+125
-96
lines changed

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ async def client_streams(
167167
def id(self) -> str | None:
168168
return self._id
169169

170+
@id.setter
171+
def id(self, value: str | None):
172+
self._id = value
173+
170174
@property
171175
def label(self) -> str:
172176
if self.id:
@@ -414,6 +418,9 @@ def _get_content(
414418
else:
415419
assert_never(resource)
416420

421+
def __eq__(self, value: object, /) -> bool:
422+
return isinstance(value, MCPServer) and self.id == value.id and self.tool_prefix == value.tool_prefix
423+
417424

418425
class MCPServerStdio(MCPServer):
419426
"""Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
@@ -568,10 +575,10 @@ def __repr__(self) -> str:
568575
return f'{self.__class__.__name__}({", ".join(repr_args)})'
569576

570577
def __eq__(self, value: object, /) -> bool:
571-
if not isinstance(value, MCPServerStdio):
572-
return False # pragma: no cover
573578
return (
574-
self.command == value.command
579+
super().__eq__(value)
580+
and isinstance(value, MCPServerStdio)
581+
and self.command == value.command
575582
and self.args == value.args
576583
and self.env == value.env
577584
and self.cwd == value.cwd
@@ -809,9 +816,7 @@ def _transport_client(self):
809816
return sse_client # pragma: no cover
810817

811818
def __eq__(self, value: object, /) -> bool:
812-
if not isinstance(value, MCPServerSSE):
813-
return False # pragma: no cover
814-
return self.url == value.url
819+
return super().__eq__(value) and isinstance(value, MCPServerSSE) and self.url == value.url
815820

816821

817822
@deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
@@ -885,9 +890,7 @@ def _transport_client(self):
885890
return streamablehttp_client # pragma: no cover
886891

887892
def __eq__(self, value: object, /) -> bool:
888-
if not isinstance(value, MCPServerStreamableHTTP):
889-
return False # pragma: no cover
890-
return self.url == value.url
893+
return super().__eq__(value) and isinstance(value, MCPServerStreamableHTTP) and self.url == value.url
891894

892895

893896
ToolResult = (
@@ -964,4 +967,11 @@ def load_mcp_servers(config_path: str | Path) -> list[MCPServerStdio | MCPServer
964967
raise FileNotFoundError(f'Config file {config_path} not found')
965968

966969
config = MCPServerConfig.model_validate_json(config_path.read_bytes())
967-
return list(config.mcp_servers.values())
970+
971+
servers: list[MCPServerStdio | MCPServerStreamableHTTP | MCPServerSSE] = []
972+
for name, server in config.mcp_servers.items():
973+
server.id = name
974+
server.tool_prefix = name
975+
servers.append(server)
976+
977+
return servers

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ dev = [
8787
"anyio>=4.5.0",
8888
"asgi-lifespan>=2.1.0",
8989
"devtools>=0.12.2",
90-
"coverage[toml]>=7.10.3",
90+
"coverage[toml]>=7.10.7",
9191
"dirty-equals>=0.9.0",
9292
"duckduckgo-search>=7.0.0",
9393
"inline-snapshot>=0.19.3",

tests/test_mcp.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,13 +1464,18 @@ def test_load_mcp_servers(tmp_path: Path):
14641464
config = tmp_path / 'mcp.json'
14651465

14661466
config.write_text('{"mcpServers": {"potato": {"url": "https://example.com/mcp"}}}')
1467-
assert load_mcp_servers(config) == snapshot([MCPServerStreamableHTTP(url='https://example.com/mcp')])
1467+
server = load_mcp_servers(config)[0]
1468+
assert server == MCPServerStreamableHTTP(url='https://example.com/mcp', id='potato', tool_prefix='potato')
14681469

14691470
config.write_text('{"mcpServers": {"potato": {"command": "python", "args": ["-m", "tests.mcp_server"]}}}')
1470-
assert load_mcp_servers(config) == snapshot([MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'])])
1471+
server = load_mcp_servers(config)[0]
1472+
assert server == MCPServerStdio(
1473+
command='python', args=['-m', 'tests.mcp_server'], id='potato', tool_prefix='potato'
1474+
)
14711475

14721476
config.write_text('{"mcpServers": {"potato": {"url": "https://example.com/sse"}}}')
1473-
assert load_mcp_servers(config) == snapshot([MCPServerSSE(url='https://example.com/sse')])
1477+
server = load_mcp_servers(config)[0]
1478+
assert server == MCPServerSSE(url='https://example.com/sse', id='potato', tool_prefix='potato')
14741479

14751480
with pytest.raises(FileNotFoundError):
14761481
load_mcp_servers(tmp_path / 'does_not_exist.json')

0 commit comments

Comments
 (0)