1111import contextlib
1212from collections .abc import Callable
1313from datetime import timedelta
14+ from types import TracebackType
1415from typing import Any , TypeAlias
1516
1617from pydantic import BaseModel
18+ from typing_extensions import Self
1719
1820import mcp
1921from mcp import types
@@ -72,6 +74,14 @@ class ClientSessionGroup:
7274 For auxiliary handlers, such as resource subscription, this is delegated to
7375 the client and can be accessed via the session. For example:
7476 mcp_session_group.get_session("server_name").subscribe_to_resource(...)
77+
78+ Example Usage:
79+ name_fn = lambda name, server_info: f"{(server_info.name)}.{name}"
80+ with async ClientSessionGroup(component_name_hook=name_fn) as group:
81+ for server_params in server_params:
82+ group.connect_to_server(server_param)
83+ ...
84+
7585 """
7686
7787 class _ComponentNames (BaseModel ):
@@ -90,6 +100,7 @@ class _ComponentNames(BaseModel):
90100 _sessions : dict [mcp .ClientSession , _ComponentNames ]
91101 _tool_to_session : dict [str , mcp .ClientSession ]
92102 _exit_stack : contextlib .AsyncExitStack
103+ _session_exit_stacks : dict [mcp .ClientSession , contextlib .AsyncExitStack ]
93104
94105 # Optional fn consuming (component_name, serverInfo) for custom names.
95106 # This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +110,7 @@ class _ComponentNames(BaseModel):
99110
100111 def __init__ (
101112 self ,
102- exit_stack : contextlib .AsyncExitStack = contextlib . AsyncExitStack () ,
113+ exit_stack : contextlib .AsyncExitStack | None = None ,
103114 component_name_hook : _ComponentNameHook | None = None ,
104115 ) -> None :
105116 """Initializes the MCP client."""
@@ -110,9 +121,32 @@ def __init__(
110121
111122 self ._sessions = {}
112123 self ._tool_to_session = {}
113- self ._exit_stack = exit_stack
124+ self ._main_exit_stack = exit_stack or contextlib .AsyncExitStack ()
125+ self ._session_exit_stacks = {}
114126 self ._component_name_hook = component_name_hook
115127
128+ async def __aenter__ (self ) -> Self :
129+ # If ClientSessionGroup itself is managing the lifecycle of _main_exit_stack
130+ # (i.e., it created it), it should enter it.
131+ # If _main_exit_stack was passed in, it's assumed the caller manages
132+ # its entry/exit.
133+ # For simplicity and consistency with how AsyncExitStack is often used when
134+ # provided as a dependency, we might not need to enter it here if it's
135+ # managed externally. However, if this class is the primary owner, entering it
136+ # ensures its 'aclose' is called even if passed in. Let's assume the
137+ # passed-in stack is already entered by the caller if needed.
138+ # For now, we just return self as the main stack's lifecycle is tied to aclose.
139+ return self
140+
141+ async def __aexit__ (
142+ self ,
143+ _exc_type : type [BaseException ] | None ,
144+ _exc_val : BaseException | None ,
145+ _exc_tb : TracebackType | None ,
146+ ) -> bool | None :
147+ await self ._main_exit_stack .aclose ()
148+ return None # Do not suppress exceptio
149+
116150 @property
117151 def prompts (self ) -> dict [str , types .Prompt ]:
118152 """Returns the prompts as a dictionary of names to prompts."""
@@ -133,31 +167,42 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu
133167 session = self ._tool_to_session [name ]
134168 return await session .call_tool (name , args )
135169
136- def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
170+ async def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
137171 """Disconnects from a single MCP server."""
138172
139- if session not in self ._sessions :
173+ session_known_for_components = session in self ._sessions
174+ session_known_for_stack = session in self ._session_exit_stacks
175+
176+ if not session_known_for_components and not session_known_for_stack :
140177 raise McpError (
141178 types .ErrorData (
142179 code = types .INVALID_PARAMS ,
143- message = "Provided session is not being managed." ,
180+ message = "Provided session is not managed or already disconnected ." ,
144181 )
145182 )
146- component_names = self ._sessions [session ]
147183
148- # Remove prompts associated with the session.
149- for name in component_names .prompts :
150- del self ._prompts [name ]
151-
152- # Remove resources associated with the session.
153- for name in component_names .resources :
154- del self ._resources [name ]
155-
156- # Remove tools associated with the session.
157- for name in component_names .tools :
158- del self ._tools [name ]
159-
160- del self ._sessions [session ]
184+ if session_known_for_components :
185+ component_names = self ._sessions .pop (session ) # Pop from _sessions tracking
186+
187+ # Remove prompts associated with the session.
188+ for name in component_names .prompts :
189+ if name in self ._prompts :
190+ del self ._prompts [name ]
191+ # Remove resources associated with the session.
192+ for name in component_names .resources :
193+ if name in self ._resources :
194+ del self ._resources [name ]
195+ # Remove tools associated with the session.
196+ for name in component_names .tools :
197+ if name in self ._tools :
198+ del self ._tools [name ]
199+ if name in self ._tool_to_session :
200+ del self ._tool_to_session [name ]
201+
202+ # Clean up the session's resources via its dedicated exit stack
203+ if session_known_for_stack :
204+ session_stack_to_close = self ._session_exit_stacks .pop (session )
205+ await session_stack_to_close .aclose ()
161206
162207 async def connect_to_server (
163208 self ,
@@ -168,102 +213,130 @@ async def connect_to_server(
168213 # Establish server connection and create session.
169214 server_info , session = await self ._establish_session (server_params )
170215
171- # Create a reverse index so we can find all prompts, resources, and
172- # tools belonging to this session. Used for removing components from
173- # the session group via self.disconnect_from_server.
174- component_names = self ._ComponentNames ()
175-
176- # Temporary components dicts. We do not want to modify the aggregate
177- # lists in case of an intermediate failure.
178- prompts_temp : dict [str , types .Prompt ] = {}
179- resources_temp : dict [str , types .Resource ] = {}
180- tools_temp : dict [str , types .Tool ] = {}
181- tool_to_session_temp : dict [str , mcp .ClientSession ] = {}
182-
183- # Query the server for its prompts and aggregate to list.
184- prompts = (await session .list_prompts ()).prompts
185- for prompt in prompts :
186- name = self ._component_name (prompt .name , server_info )
187- if name in self ._prompts :
188- raise McpError (
189- types .ErrorData (
190- code = types .INVALID_PARAMS ,
191- message = f"{ name } already exists in group prompts." ,
216+ try :
217+ # Create a reverse index so we can find all prompts, resources, and
218+ # tools belonging to this session. Used for removing components from
219+ # the session group via self.disconnect_from_server.
220+ component_names = self ._ComponentNames ()
221+
222+ # Temporary components dicts. We do not want to modify the aggregate
223+ # lists in case of an intermediate failure.
224+ prompts_temp : dict [str , types .Prompt ] = {}
225+ resources_temp : dict [str , types .Resource ] = {}
226+ tools_temp : dict [str , types .Tool ] = {}
227+ tool_to_session_temp : dict [str , mcp .ClientSession ] = {}
228+
229+ # Query the server for its prompts and aggregate to list.
230+ prompts = (await session .list_prompts ()).prompts
231+ for prompt in prompts :
232+ name = self ._component_name (prompt .name , server_info )
233+ if name in self ._prompts :
234+ raise McpError (
235+ types .ErrorData (
236+ code = types .INVALID_PARAMS ,
237+ message = f"{ name } already exists in group prompts." ,
238+ )
192239 )
193- )
194- prompts_temp [ name ] = prompt
195- component_names . prompts . add ( name )
196-
197- # Query the server for its resources and aggregate to list.
198- resources = ( await session . list_resources ()). resources
199- for resource in resources :
200- name = self ._component_name ( resource . name , server_info )
201- if name in self . _resources :
202- raise McpError (
203- types .ErrorData (
204- code = types . INVALID_PARAMS ,
205- message = f" { name } already exists in group resources." ,
240+ prompts_temp [ name ] = prompt
241+ component_names . prompts . add ( name )
242+
243+ # Query the server for its resources and aggregate to list.
244+ resources = ( await session . list_resources ()). resources
245+ for resource in resources :
246+ name = self . _component_name ( resource . name , server_info )
247+ if name in self ._resources :
248+ raise McpError (
249+ types . ErrorData (
250+ code = types .INVALID_PARAMS ,
251+ message = f" { name } already exists in group resources." ,
252+ )
206253 )
207- )
208- resources_temp [ name ] = resource
209- component_names . resources . add ( name )
210-
211- # Query the server for its tools and aggregate to list.
212- tools = ( await session . list_tools ()). tools
213- for tool in tools :
214- name = self ._component_name ( tool . name , server_info )
215- if name in self . _tools :
216- raise McpError (
217- types .ErrorData (
218- code = types . INVALID_PARAMS ,
219- message = f" { name } already exists in group tools." ,
254+ resources_temp [ name ] = resource
255+ component_names . resources . add ( name )
256+
257+ # Query the server for its tools and aggregate to list.
258+ tools = ( await session . list_tools ()). tools
259+ for tool in tools :
260+ name = self . _component_name ( tool . name , server_info )
261+ if name in self ._tools :
262+ raise McpError (
263+ types . ErrorData (
264+ code = types .INVALID_PARAMS ,
265+ message = f" { name } already exists in group tools." ,
266+ )
220267 )
221- )
222- tools_temp [name ] = tool
223- tool_to_session_temp [name ] = session
224- component_names .tools .add (name )
225-
226- # Aggregate components.
227- self ._sessions [session ] = component_names
228- self ._prompts .update (prompts_temp )
229- self ._resources .update (resources_temp )
230- self ._tools .update (tools_temp )
231- self ._tool_to_session .update (tool_to_session_temp )
232-
233- return session
268+ tools_temp [name ] = tool
269+ tool_to_session_temp [name ] = session
270+ component_names .tools .add (name )
271+
272+ # Aggregate components.
273+ self ._sessions [session ] = component_names
274+ self ._prompts .update (prompts_temp )
275+ self ._resources .update (resources_temp )
276+ self ._tools .update (tools_temp )
277+ self ._tool_to_session .update (tool_to_session_temp )
278+
279+ return session
280+ except Exception :
281+ # If any error occurs during component fetching or registration
282+ # after session establishment, clean up the newly established
283+ # session's resources.
284+ if session in self ._session_exit_stacks :
285+ session_stack_to_close = self ._session_exit_stacks .pop (session )
286+ await session_stack_to_close .aclose ()
287+ # The session_stack_to_close was also entered into
288+ # self._exit_stack. Its .aclose() being called here means when
289+ # self._exit_stack later tries to close it, it will be a
290+ # no-op, which is fine.
291+ raise
234292
235293 async def _establish_session (
236294 self , server_params : ServerParameters
237295 ) -> tuple [types .Implementation , mcp .ClientSession ]:
238296 """Establish a client session to an MCP server."""
239297
240- # Create read and write streams that facilitate io with the server.
241- if isinstance (server_params , StdioServerParameters ):
242- client = mcp .stdio_client (server_params )
243- read , write = await self ._exit_stack .enter_async_context (client )
244- elif isinstance (server_params , SseServerParameters ):
245- client = sse_client (
246- url = server_params .url ,
247- headers = server_params .headers ,
248- timeout = server_params .timeout ,
249- sse_read_timeout = server_params .sse_read_timeout ,
250- )
251- read , write = await self ._exit_stack .enter_async_context (client )
252- else :
253- client = streamablehttp_client (
254- url = server_params .url ,
255- headers = server_params .headers ,
256- timeout = server_params .timeout ,
257- sse_read_timeout = server_params .sse_read_timeout ,
258- terminate_on_close = server_params .terminate_on_close ,
259- )
260- read , write , _ = await self ._exit_stack .enter_async_context (client )
298+ session_specific_stack = contextlib .AsyncExitStack ()
299+ try :
300+ # Create read and write streams that facilitate io with the server.
301+ if isinstance (server_params , StdioServerParameters ):
302+ client = mcp .stdio_client (server_params )
303+ read , write = await self ._exit_stack .enter_async_context (client )
304+ elif isinstance (server_params , SseServerParameters ):
305+ client = sse_client (
306+ url = server_params .url ,
307+ headers = server_params .headers ,
308+ timeout = server_params .timeout ,
309+ sse_read_timeout = server_params .sse_read_timeout ,
310+ )
311+ read , write = await self ._exit_stack .enter_async_context (client )
312+ else :
313+ client = streamablehttp_client (
314+ url = server_params .url ,
315+ headers = server_params .headers ,
316+ timeout = server_params .timeout ,
317+ sse_read_timeout = server_params .sse_read_timeout ,
318+ terminate_on_close = server_params .terminate_on_close ,
319+ )
320+ read , write , _ = await self ._exit_stack .enter_async_context (client )
261321
262- session = await self ._exit_stack .enter_async_context (
263- mcp .ClientSession (read , write )
264- )
265- result = await session .initialize ()
266- return result .serverInfo , session
322+ session = await self ._exit_stack .enter_async_context (
323+ mcp .ClientSession (read , write )
324+ )
325+ result = await session .initialize ()
326+
327+ # Session successfully initialized.
328+ # Store its stack and register the stack with the main group stack.
329+ self ._session_exit_stacks [session ] = session_specific_stack
330+ # session_specific_stack itself becomes a resource managed by the
331+ # main _exit_stack.
332+ await self ._exit_stack .enter_async_context (session_specific_stack )
333+
334+ return result .serverInfo , session
335+ except Exception :
336+ # If anything during this setup fails, ensure the session-specific
337+ # stack is closed.
338+ await session_specific_stack .aclose ()
339+ raise
267340
268341 def _component_name (self , name : str , server_info : types .Implementation ) -> str :
269342 if self ._component_name_hook :
0 commit comments