Skip to content

Commit 26bac17

Browse files
committed
Create ClientSessionGroup for managing multiple session connections.
This abstraction concurrently manages multiple MCP session connections. Tools, resources, and prompts are aggregated across servers. Servers may be connected to or disconnected from at any point after initialization. This abstractions can handle naming collisions using a custom user-provided hook.
1 parent 1a330ac commit 26bac17

File tree

3 files changed

+486
-0
lines changed

3 files changed

+486
-0
lines changed

src/mcp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .client.session import ClientSession
2+
from .client.session_group import ClientSessionGroup
23
from .client.stdio import StdioServerParameters, stdio_client
34
from .server.session import ServerSession
45
from .server.stdio import stdio_server
@@ -63,6 +64,7 @@
6364
"ClientRequest",
6465
"ClientResult",
6566
"ClientSession",
67+
"ClientSessionGroup",
6668
"CreateMessageRequest",
6769
"CreateMessageResult",
6870
"ErrorData",

src/mcp/client/session_group.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""
2+
SessionGroup concurrently manages multiple MCP session connections.
3+
4+
Tools, resources, and prompts are aggregated across servers. Servers may
5+
be connected to or disconnected from at any point after initialization.
6+
7+
This abstractions can handle naming collisions using a custom user-provided
8+
hook.
9+
"""
10+
11+
import contextlib
12+
from collections.abc import Callable
13+
from typing import Any, TypeAlias
14+
15+
from pydantic import BaseModel
16+
17+
import mcp
18+
from mcp import types
19+
from mcp.client.stdio import StdioServerParameters
20+
from mcp.shared.exceptions import McpError
21+
22+
23+
class ClientSessionGroup:
24+
"""Client for managing connections to multiple MCP servers.
25+
26+
This class is responsible for encapsulating management of server connections.
27+
It it aggregates tools, resources, and prompts from all connected servers.
28+
29+
For auxiliary handlers, such as resource subscription, this is delegated to
30+
the client and can be accessed via the session. For example:
31+
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
32+
"""
33+
34+
class _ComponentNames(BaseModel):
35+
"""Used for reverse index to find components."""
36+
37+
prompts: set[str] = set()
38+
resources: set[str] = set()
39+
tools: set[str] = set()
40+
41+
# Standard MCP components.
42+
_prompts: dict[str, types.Prompt]
43+
_resources: dict[str, types.Resource]
44+
_tools: dict[str, types.Tool]
45+
46+
# Client-server connection management.
47+
_sessions: dict[mcp.ClientSession, _ComponentNames]
48+
_tool_to_session: dict[str, mcp.ClientSession]
49+
_exit_stack: contextlib.AsyncExitStack
50+
51+
# Optional fn consuming (component_name, serverInfo) for custom names.
52+
# This is provide a means to mitigate naming conflicts across servers.
53+
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
54+
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
55+
_component_name_hook: _ComponentNameHook | None
56+
57+
def __init__(
58+
self,
59+
exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(),
60+
component_name_hook: _ComponentNameHook | None = None,
61+
) -> None:
62+
"""Initializes the MCP client."""
63+
64+
self._tools = {}
65+
self._resources = {}
66+
self._prompts = {}
67+
68+
self._sessions = {}
69+
self._tool_to_session = {}
70+
self._exit_stack = exit_stack
71+
self._component_name_hook = component_name_hook
72+
73+
@property
74+
def prompts(self) -> dict[str, types.Prompt]:
75+
"""Returns the prompts as a dictionary of names to prompts."""
76+
return self._prompts
77+
78+
@property
79+
def resources(self) -> dict[str, types.Resource]:
80+
"""Returns the resources as a dictionary of names to resources."""
81+
return self._resources
82+
83+
@property
84+
def tools(self) -> dict[str, types.Tool]:
85+
"""Returns the tools as a dictionary of names to tools."""
86+
return self._tools
87+
88+
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
89+
"""Executes a tool given its name and arguments."""
90+
session = self._tool_to_session[name]
91+
return await session.call_tool(name, args)
92+
93+
def disconnect_from_server(self, session: mcp.ClientSession) -> None:
94+
"""Disconnects from a single MCP server."""
95+
96+
if session not in self._sessions:
97+
raise McpError(
98+
types.ErrorData(
99+
code=types.INVALID_PARAMS,
100+
message="Provided session is not being managed.",
101+
)
102+
)
103+
component_names = self._sessions[session]
104+
105+
# Remove prompts associated with the session.
106+
for name in component_names.prompts:
107+
del self._prompts[name]
108+
109+
# Remove resources associated with the session.
110+
for name in component_names.resources:
111+
del self._resources[name]
112+
113+
# Remove tools associated with the session.
114+
for name in component_names.tools:
115+
del self._tools[name]
116+
117+
del self._sessions[session]
118+
119+
async def connect_to_server(
120+
self,
121+
server_params: StdioServerParameters,
122+
) -> mcp.ClientSession:
123+
"""Connects to a single MCP server."""
124+
125+
# Establish server connection and create session.
126+
server_info, session = await self._establish_session(server_params)
127+
128+
# Create a reverse index so we can find all prompts, resources, and
129+
# tools belonging to this session. Used for removing components from
130+
# the session group via self.disconnect_from_server.
131+
component_names = self._ComponentNames()
132+
133+
# Temporary components dicts. We do not want to modify the aggregate
134+
# lists in case of an intermediate failure.
135+
prompts_temp: dict[str, types.Prompt] = {}
136+
resources_temp: dict[str, types.Resource] = {}
137+
tools_temp: dict[str, types.Tool] = {}
138+
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
139+
140+
# Query the server for its prompts and aggregate to list.
141+
prompts = (await session.list_prompts()).prompts
142+
for prompt in prompts:
143+
name = self._component_name(prompt.name, server_info)
144+
if name in self._prompts:
145+
raise McpError(
146+
types.ErrorData(
147+
code=types.INVALID_PARAMS,
148+
message=f"{name} already exists in group prompts.",
149+
)
150+
)
151+
prompts_temp[name] = prompt
152+
component_names.prompts.add(name)
153+
154+
# Query the server for its resources and aggregate to list.
155+
resources = (await session.list_resources()).resources
156+
for resource in resources:
157+
name = self._component_name(resource.name, server_info)
158+
if name in self._resources:
159+
raise McpError(
160+
types.ErrorData(
161+
code=types.INVALID_PARAMS,
162+
message=f"{name} already exists in group resources.",
163+
)
164+
)
165+
resources_temp[name] = resource
166+
component_names.resources.add(name)
167+
168+
# Query the server for its tools and aggregate to list.
169+
tools = (await session.list_tools()).tools
170+
for tool in tools:
171+
name = self._component_name(tool.name, server_info)
172+
if name in self._tools:
173+
raise McpError(
174+
types.ErrorData(
175+
code=types.INVALID_PARAMS,
176+
message=f"{name} already exists in group tools.",
177+
)
178+
)
179+
tools_temp[name] = tool
180+
tool_to_session_temp[name] = session
181+
component_names.tools.add(name)
182+
183+
# Aggregate components.
184+
self._sessions[session] = component_names
185+
self._prompts.update(prompts_temp)
186+
self._resources.update(resources_temp)
187+
self._tools.update(tools_temp)
188+
self._tool_to_session.update(tool_to_session_temp)
189+
190+
return session
191+
192+
async def _establish_session(
193+
self, server_params: StdioServerParameters
194+
) -> tuple[types.Implementation, mcp.ClientSession]:
195+
"""Establish a client session to an MCP server."""
196+
client = mcp.stdio_client(server_params)
197+
read, write = await self._exit_stack.enter_async_context(client)
198+
session = await self._exit_stack.enter_async_context(
199+
mcp.ClientSession(read, write)
200+
)
201+
result = await session.initialize()
202+
return result.serverInfo, session
203+
204+
def _component_name(self, name: str, server_info: types.Implementation) -> str:
205+
if self._component_name_hook:
206+
return self._component_name_hook(name, server_info)
207+
return name

0 commit comments

Comments
 (0)