Skip to content

Commit b6c6275

Browse files
committed
Add support for async context management to ClientSessionGroup
This changes enables context management for setting up and tearing down async exit stacks durring server connection and disconnection respectively. Documentation has been added to show an example use case that demonstrates how `ClientSessionGroup` can be used with `async with`.
1 parent ef3feb2 commit b6c6275

File tree

2 files changed

+182
-109
lines changed

2 files changed

+182
-109
lines changed

src/mcp/client/session_group.py

Lines changed: 178 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
import contextlib
1212
from collections.abc import Callable
1313
from datetime import timedelta
14+
from types import TracebackType
1415
from typing import Any, TypeAlias
1516

1617
from pydantic import BaseModel
18+
from typing_extensions import Self
1719

1820
import mcp
1921
from mcp import types
@@ -72,6 +74,14 @@ class ClientSessionGroup:
7274
For auxiliary handlers, such as resource subscription, this is delegated to
7375
the client and can be accessed via the session. For example:
7476
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
77+
78+
Example Usage:
79+
name_fn = lambda name, server_info: f"{(server_info.name)}.{name}"
80+
with async ClientSessionGroup(component_name_hook=name_fn) as group:
81+
for server_params in server_params:
82+
group.connect_to_server(server_param)
83+
...
84+
7585
"""
7686

7787
class _ComponentNames(BaseModel):
@@ -90,6 +100,7 @@ class _ComponentNames(BaseModel):
90100
_sessions: dict[mcp.ClientSession, _ComponentNames]
91101
_tool_to_session: dict[str, mcp.ClientSession]
92102
_exit_stack: contextlib.AsyncExitStack
103+
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
93104

94105
# Optional fn consuming (component_name, serverInfo) for custom names.
95106
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +110,7 @@ class _ComponentNames(BaseModel):
99110

100111
def __init__(
101112
self,
102-
exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(),
113+
exit_stack: contextlib.AsyncExitStack | None = None,
103114
component_name_hook: _ComponentNameHook | None = None,
104115
) -> None:
105116
"""Initializes the MCP client."""
@@ -110,9 +121,32 @@ def __init__(
110121

111122
self._sessions = {}
112123
self._tool_to_session = {}
113-
self._exit_stack = exit_stack
124+
self._main_exit_stack = exit_stack or contextlib.AsyncExitStack()
125+
self._session_exit_stacks = {}
114126
self._component_name_hook = component_name_hook
115127

128+
async def __aenter__(self) -> Self:
129+
# If ClientSessionGroup itself is managing the lifecycle of _main_exit_stack
130+
# (i.e., it created it), it should enter it.
131+
# If _main_exit_stack was passed in, it's assumed the caller manages
132+
# its entry/exit.
133+
# For simplicity and consistency with how AsyncExitStack is often used when
134+
# provided as a dependency, we might not need to enter it here if it's
135+
# managed externally. However, if this class is the primary owner, entering it
136+
# ensures its 'aclose' is called even if passed in. Let's assume the
137+
# passed-in stack is already entered by the caller if needed.
138+
# For now, we just return self as the main stack's lifecycle is tied to aclose.
139+
return self
140+
141+
async def __aexit__(
142+
self,
143+
_exc_type: type[BaseException] | None,
144+
_exc_val: BaseException | None,
145+
_exc_tb: TracebackType | None,
146+
) -> bool | None:
147+
await self._main_exit_stack.aclose()
148+
return None # Do not suppress exceptio
149+
116150
@property
117151
def prompts(self) -> dict[str, types.Prompt]:
118152
"""Returns the prompts as a dictionary of names to prompts."""
@@ -133,31 +167,42 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu
133167
session = self._tool_to_session[name]
134168
return await session.call_tool(name, args)
135169

136-
def disconnect_from_server(self, session: mcp.ClientSession) -> None:
170+
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
137171
"""Disconnects from a single MCP server."""
138172

139-
if session not in self._sessions:
173+
session_known_for_components = session in self._sessions
174+
session_known_for_stack = session in self._session_exit_stacks
175+
176+
if not session_known_for_components and not session_known_for_stack:
140177
raise McpError(
141178
types.ErrorData(
142179
code=types.INVALID_PARAMS,
143-
message="Provided session is not being managed.",
180+
message="Provided session is not managed or already disconnected.",
144181
)
145182
)
146-
component_names = self._sessions[session]
147183

148-
# Remove prompts associated with the session.
149-
for name in component_names.prompts:
150-
del self._prompts[name]
151-
152-
# Remove resources associated with the session.
153-
for name in component_names.resources:
154-
del self._resources[name]
155-
156-
# Remove tools associated with the session.
157-
for name in component_names.tools:
158-
del self._tools[name]
159-
160-
del self._sessions[session]
184+
if session_known_for_components:
185+
component_names = self._sessions.pop(session) # Pop from _sessions tracking
186+
187+
# Remove prompts associated with the session.
188+
for name in component_names.prompts:
189+
if name in self._prompts:
190+
del self._prompts[name]
191+
# Remove resources associated with the session.
192+
for name in component_names.resources:
193+
if name in self._resources:
194+
del self._resources[name]
195+
# Remove tools associated with the session.
196+
for name in component_names.tools:
197+
if name in self._tools:
198+
del self._tools[name]
199+
if name in self._tool_to_session:
200+
del self._tool_to_session[name]
201+
202+
# Clean up the session's resources via its dedicated exit stack
203+
if session_known_for_stack:
204+
session_stack_to_close = self._session_exit_stacks.pop(session)
205+
await session_stack_to_close.aclose()
161206

162207
async def connect_to_server(
163208
self,
@@ -168,102 +213,130 @@ async def connect_to_server(
168213
# Establish server connection and create session.
169214
server_info, session = await self._establish_session(server_params)
170215

171-
# Create a reverse index so we can find all prompts, resources, and
172-
# tools belonging to this session. Used for removing components from
173-
# the session group via self.disconnect_from_server.
174-
component_names = self._ComponentNames()
175-
176-
# Temporary components dicts. We do not want to modify the aggregate
177-
# lists in case of an intermediate failure.
178-
prompts_temp: dict[str, types.Prompt] = {}
179-
resources_temp: dict[str, types.Resource] = {}
180-
tools_temp: dict[str, types.Tool] = {}
181-
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
182-
183-
# Query the server for its prompts and aggregate to list.
184-
prompts = (await session.list_prompts()).prompts
185-
for prompt in prompts:
186-
name = self._component_name(prompt.name, server_info)
187-
if name in self._prompts:
188-
raise McpError(
189-
types.ErrorData(
190-
code=types.INVALID_PARAMS,
191-
message=f"{name} already exists in group prompts.",
216+
try:
217+
# Create a reverse index so we can find all prompts, resources, and
218+
# tools belonging to this session. Used for removing components from
219+
# the session group via self.disconnect_from_server.
220+
component_names = self._ComponentNames()
221+
222+
# Temporary components dicts. We do not want to modify the aggregate
223+
# lists in case of an intermediate failure.
224+
prompts_temp: dict[str, types.Prompt] = {}
225+
resources_temp: dict[str, types.Resource] = {}
226+
tools_temp: dict[str, types.Tool] = {}
227+
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
228+
229+
# Query the server for its prompts and aggregate to list.
230+
prompts = (await session.list_prompts()).prompts
231+
for prompt in prompts:
232+
name = self._component_name(prompt.name, server_info)
233+
if name in self._prompts:
234+
raise McpError(
235+
types.ErrorData(
236+
code=types.INVALID_PARAMS,
237+
message=f"{name} already exists in group prompts.",
238+
)
192239
)
193-
)
194-
prompts_temp[name] = prompt
195-
component_names.prompts.add(name)
196-
197-
# Query the server for its resources and aggregate to list.
198-
resources = (await session.list_resources()).resources
199-
for resource in resources:
200-
name = self._component_name(resource.name, server_info)
201-
if name in self._resources:
202-
raise McpError(
203-
types.ErrorData(
204-
code=types.INVALID_PARAMS,
205-
message=f"{name} already exists in group resources.",
240+
prompts_temp[name] = prompt
241+
component_names.prompts.add(name)
242+
243+
# Query the server for its resources and aggregate to list.
244+
resources = (await session.list_resources()).resources
245+
for resource in resources:
246+
name = self._component_name(resource.name, server_info)
247+
if name in self._resources:
248+
raise McpError(
249+
types.ErrorData(
250+
code=types.INVALID_PARAMS,
251+
message=f"{name} already exists in group resources.",
252+
)
206253
)
207-
)
208-
resources_temp[name] = resource
209-
component_names.resources.add(name)
210-
211-
# Query the server for its tools and aggregate to list.
212-
tools = (await session.list_tools()).tools
213-
for tool in tools:
214-
name = self._component_name(tool.name, server_info)
215-
if name in self._tools:
216-
raise McpError(
217-
types.ErrorData(
218-
code=types.INVALID_PARAMS,
219-
message=f"{name} already exists in group tools.",
254+
resources_temp[name] = resource
255+
component_names.resources.add(name)
256+
257+
# Query the server for its tools and aggregate to list.
258+
tools = (await session.list_tools()).tools
259+
for tool in tools:
260+
name = self._component_name(tool.name, server_info)
261+
if name in self._tools:
262+
raise McpError(
263+
types.ErrorData(
264+
code=types.INVALID_PARAMS,
265+
message=f"{name} already exists in group tools.",
266+
)
220267
)
221-
)
222-
tools_temp[name] = tool
223-
tool_to_session_temp[name] = session
224-
component_names.tools.add(name)
225-
226-
# Aggregate components.
227-
self._sessions[session] = component_names
228-
self._prompts.update(prompts_temp)
229-
self._resources.update(resources_temp)
230-
self._tools.update(tools_temp)
231-
self._tool_to_session.update(tool_to_session_temp)
232-
233-
return session
268+
tools_temp[name] = tool
269+
tool_to_session_temp[name] = session
270+
component_names.tools.add(name)
271+
272+
# Aggregate components.
273+
self._sessions[session] = component_names
274+
self._prompts.update(prompts_temp)
275+
self._resources.update(resources_temp)
276+
self._tools.update(tools_temp)
277+
self._tool_to_session.update(tool_to_session_temp)
278+
279+
return session
280+
except Exception:
281+
# If any error occurs during component fetching or registration
282+
# after session establishment, clean up the newly established
283+
# session's resources.
284+
if session in self._session_exit_stacks:
285+
session_stack_to_close = self._session_exit_stacks.pop(session)
286+
await session_stack_to_close.aclose()
287+
# The session_stack_to_close was also entered into
288+
# self._exit_stack. Its .aclose() being called here means when
289+
# self._exit_stack later tries to close it, it will be a
290+
# no-op, which is fine.
291+
raise
234292

235293
async def _establish_session(
236294
self, server_params: ServerParameters
237295
) -> tuple[types.Implementation, mcp.ClientSession]:
238296
"""Establish a client session to an MCP server."""
239297

240-
# Create read and write streams that facilitate io with the server.
241-
if isinstance(server_params, StdioServerParameters):
242-
client = mcp.stdio_client(server_params)
243-
read, write = await self._exit_stack.enter_async_context(client)
244-
elif isinstance(server_params, SseServerParameters):
245-
client = sse_client(
246-
url=server_params.url,
247-
headers=server_params.headers,
248-
timeout=server_params.timeout,
249-
sse_read_timeout=server_params.sse_read_timeout,
250-
)
251-
read, write = await self._exit_stack.enter_async_context(client)
252-
else:
253-
client = streamablehttp_client(
254-
url=server_params.url,
255-
headers=server_params.headers,
256-
timeout=server_params.timeout,
257-
sse_read_timeout=server_params.sse_read_timeout,
258-
terminate_on_close=server_params.terminate_on_close,
259-
)
260-
read, write, _ = await self._exit_stack.enter_async_context(client)
298+
session_specific_stack = contextlib.AsyncExitStack()
299+
try:
300+
# Create read and write streams that facilitate io with the server.
301+
if isinstance(server_params, StdioServerParameters):
302+
client = mcp.stdio_client(server_params)
303+
read, write = await self._exit_stack.enter_async_context(client)
304+
elif isinstance(server_params, SseServerParameters):
305+
client = sse_client(
306+
url=server_params.url,
307+
headers=server_params.headers,
308+
timeout=server_params.timeout,
309+
sse_read_timeout=server_params.sse_read_timeout,
310+
)
311+
read, write = await self._exit_stack.enter_async_context(client)
312+
else:
313+
client = streamablehttp_client(
314+
url=server_params.url,
315+
headers=server_params.headers,
316+
timeout=server_params.timeout,
317+
sse_read_timeout=server_params.sse_read_timeout,
318+
terminate_on_close=server_params.terminate_on_close,
319+
)
320+
read, write, _ = await self._exit_stack.enter_async_context(client)
261321

262-
session = await self._exit_stack.enter_async_context(
263-
mcp.ClientSession(read, write)
264-
)
265-
result = await session.initialize()
266-
return result.serverInfo, session
322+
session = await self._exit_stack.enter_async_context(
323+
mcp.ClientSession(read, write)
324+
)
325+
result = await session.initialize()
326+
327+
# Session successfully initialized.
328+
# Store its stack and register the stack with the main group stack.
329+
self._session_exit_stacks[session] = session_specific_stack
330+
# session_specific_stack itself becomes a resource managed by the
331+
# main _exit_stack.
332+
await self._exit_stack.enter_async_context(session_specific_stack)
333+
334+
return result.serverInfo, session
335+
except Exception:
336+
# If anything during this setup fails, ensure the session-specific
337+
# stack is closed.
338+
await session_specific_stack.aclose()
339+
raise
267340

268341
def _component_name(self, name: str, server_info: types.Implementation) -> str:
269342
if self._component_name_hook:

tests/client/test_session_group.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str:
151151
assert group.tools[expected_tool_name] == mock_tool
152152
assert group._tool_to_session[expected_tool_name] == mock_session
153153

154-
def test_disconnect_from_server(self): # No mock arguments needed
154+
async def test_disconnect_from_server(self): # No mock arguments needed
155155
"""Test disconnecting from a server."""
156156
# --- Test Setup ---
157157
group = ClientSessionGroup()
@@ -205,7 +205,7 @@ def test_disconnect_from_server(self): # No mock arguments needed
205205
assert "prm1" in group._prompts
206206

207207
# --- Test Execution ---
208-
group.disconnect_from_server(mock_session)
208+
await group.disconnect_from_server(mock_session)
209209

210210
# --- Assertions ---
211211
assert mock_session not in group._sessions
@@ -263,12 +263,12 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta
263263
) # Ensure it's the original mock
264264

265265
# No patching needed here
266-
def test_disconnect_non_existent_server(self):
266+
async def test_disconnect_non_existent_server(self):
267267
"""Test disconnecting a server that isn't connected."""
268268
session = mock.Mock(spec=mcp.ClientSession)
269269
group = ClientSessionGroup()
270270
with pytest.raises(McpError):
271-
group.disconnect_from_server(session)
271+
await group.disconnect_from_server(session)
272272

273273
@pytest.mark.parametrize(
274274
"server_params_instance, client_type_name, patch_target_for_client_func",

0 commit comments

Comments
 (0)