Skip to content

Commit 4e5bc71

Browse files
committed
refactor: add mcp client session manager
1 parent 2c2bed0 commit 4e5bc71

File tree

3 files changed

+392
-1
lines changed

3 files changed

+392
-1
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import asyncio
2+
import logging
3+
from typing import Optional, TextIO, cast
4+
5+
from mcp import ClientSession, InitializeResult, StdioServerParameters, stdio_client
6+
from mcp.client.sse import sse_client
7+
8+
from mcpm.core.schema import ServerConfig, SSEServerConfig, STDIOServerConfig
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def _stdio_transport_context(server_config: ServerConfig, errlog: TextIO):
14+
server_config = cast(STDIOServerConfig, server_config)
15+
server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
16+
return stdio_client(server_params, errlog=errlog)
17+
18+
19+
def _sse_transport_context(server_config: ServerConfig):
20+
server_config = cast(SSEServerConfig, server_config)
21+
return sse_client(server_config.url, headers=server_config.headers)
22+
23+
24+
class ServerConnection:
25+
def __init__(self, server_config: ServerConfig, errlog: TextIO) -> None:
26+
self.session: Optional[ClientSession] = None
27+
self.session_initialized_response: Optional[InitializeResult] = None
28+
self._initialized = False
29+
self.server_config = server_config
30+
self._initialized_event = asyncio.Event()
31+
self._shutdown_event = asyncio.Event()
32+
self._errlog = errlog
33+
34+
self._transport_context_factory = (
35+
lambda config: _stdio_transport_context(config, errlog=self._errlog)
36+
if isinstance(config, STDIOServerConfig)
37+
else _sse_transport_context(config)
38+
)
39+
40+
self._server_task = asyncio.create_task(self._server_lifespan_cycle())
41+
42+
def healthy(self) -> bool:
43+
return self.session is not None and self._initialized
44+
45+
# block until client session is initialized
46+
async def wait_for_initialization(self):
47+
await self._initialized_event.wait()
48+
49+
# request for client session to gracefully close
50+
async def request_for_shutdown(self):
51+
self._shutdown_event.set()
52+
53+
# block until client session is shutdown
54+
async def wait_for_shutdown_request(self):
55+
await self._shutdown_event.wait()
56+
57+
async def _server_lifespan_cycle(self):
58+
try:
59+
async with self._transport_context_factory(self.server_config) as (read, write):
60+
async with ClientSession(read, write) as session:
61+
self.session_initialized_response = await session.initialize()
62+
63+
self.session = session
64+
self._initialized = True
65+
self._initialized_event.set()
66+
# block here so that the session will not be closed after exit scope
67+
# we could retrieve alive session through self.session
68+
await self.wait_for_shutdown_request()
69+
except Exception as e:
70+
logger.error(f"Failed to connect to server {self.server_config.name}: {e}")
71+
self._initialized_event.set()
72+
self._shutdown_event.set()

src/mcpm/core/router/manager.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import logging
2+
from collections import defaultdict
3+
from typing import Any, Dict, TextIO, Union
4+
5+
from mcp.server import NotificationOptions
6+
from mcp.types import (
7+
LoggingCapability,
8+
Prompt,
9+
PromptsCapability,
10+
Resource,
11+
ResourcesCapability,
12+
ResourceTemplate,
13+
ServerCapabilities,
14+
Tool,
15+
ToolsCapability,
16+
)
17+
from pydantic import AnyUrl
18+
19+
from mcpm.core.router.client_connection import ServerConnection
20+
from mcpm.core.schema import ResourceType, ServerConfig
21+
from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, TOOL_SPLITOR
22+
from mcpm.utils.errlog_manager import ServerErrorLogManager
23+
24+
logger = logging.getLogger(__name__)
25+
26+
class MCPClientSessionManager:
27+
28+
_instance = None
29+
30+
@classmethod
31+
def get_instance(cls):
32+
if cls._instance is None:
33+
cls._instance = cls()
34+
return cls._instance
35+
36+
def __init__(self, strict_mode: bool = False) -> None:
37+
"""
38+
MCPClientSessionManager serves as a central manager for MCP client sessions.
39+
40+
It manages connections to MCP servers, maintains session health, assembles server capabilities,
41+
and provides access to tools, prompts, and resources across all registered servers.
42+
43+
Args:
44+
strict_mode: If True, raises errors when duplicate capabilities are detected
45+
across servers. If False, automatically resolves conflicts by
46+
adding server name prefixes.
47+
"""
48+
self.strict_mode = strict_mode
49+
self.sessions: Dict[str, ServerConnection] = {}
50+
51+
self.capabilities_mapping: Dict[str, Dict[str, Any]] = defaultdict(dict)
52+
self.capabilities_to_server_id: Dict[ResourceType, Dict[str, Any]] = defaultdict(dict)
53+
# real name mapping
54+
self.tools_mapping: Dict[str, Tool] = {}
55+
self.prompts_mapping: Dict[str, Prompt] = {}
56+
self.resources_mapping: Dict[str, Resource] = {}
57+
self.resources_templates_mapping: Dict[str, ResourceTemplate] = {}
58+
# error log stream
59+
self.error_log_manager: ServerErrorLogManager = ServerErrorLogManager()
60+
61+
def get_alive_sessions(self):
62+
# returns the server_id list of all alive sessions
63+
return [server_id for server_id, session in self.sessions.items() if session.healthy()]
64+
65+
async def _add_session_impl(self, server_id: str, server_config: ServerConfig):
66+
# initialize session
67+
if server_id in self.sessions:
68+
logger.warning(f"Server with ID {server_id} already exists")
69+
return
70+
71+
errlog: TextIO = self.error_log_manager.open_errlog_file(server_id)
72+
client_session = ServerConnection(server_config, errlog=errlog)
73+
await client_session.wait_for_initialization()
74+
75+
if not client_session.healthy():
76+
logger.warning(f"Failed to connect to server {server_id}")
77+
self.error_log_manager.close_errlog_file(server_id)
78+
return
79+
80+
self.sessions[server_id] = client_session
81+
82+
# update capabilities
83+
initialized_response = client_session.session_initialized_response
84+
self.capabilities_mapping[server_id] = initialized_response.capabilities.model_dump() # type: ignore
85+
86+
# update tools
87+
if initialized_response.capabilities.tools: # type: ignore
88+
await self._assemble_tools(server_id, client_session)
89+
# update prompts
90+
if initialized_response.capabilities.prompts: # type: ignore
91+
await self._assemble_prompts(server_id, client_session)
92+
# update resources
93+
if initialized_response.capabilities.resources: # type: ignore
94+
await self._assemble_resources(server_id, client_session)
95+
96+
async def add_session(self, server_id: str, server_config: ServerConfig) -> bool:
97+
try:
98+
# add log to see whether client session initialization is blocked in this step
99+
logger.info(f"Ready to add server: {server_config.name}")
100+
await self._add_session_impl(server_id, server_config)
101+
logger.info(f"Server {server_config.name} added successfully")
102+
return True
103+
except Exception as e:
104+
# if went wrong, skip the update
105+
logger.error(f"Failed to add server {server_config.name}: {e}")
106+
# if error log is opened, close it
107+
self.error_log_manager.close_errlog_file(server_id)
108+
109+
return False
110+
111+
async def remove_session(self, server_id: str):
112+
if server_id not in self.sessions:
113+
logger.warning(f"Server with ID {server_id} does not exist")
114+
return
115+
116+
client_session = self.sessions.pop(server_id)
117+
118+
await client_session.request_for_shutdown()
119+
120+
self.capabilities_mapping.pop(server_id)
121+
# close error log
122+
self.error_log_manager.close_errlog_file(server_id)
123+
# remove all mapping reference
124+
for key in list(self.tools_mapping.keys()):
125+
if self.capabilities_to_server_id[ResourceType.TOOL].get(key) == server_id:
126+
self.tools_mapping.pop(key)
127+
self.capabilities_to_server_id[ResourceType.TOOL].pop(key)
128+
for key in list(self.prompts_mapping.keys()):
129+
if self.capabilities_to_server_id[ResourceType.PROMPT].get(key) == server_id:
130+
self.prompts_mapping.pop(key)
131+
self.capabilities_to_server_id[ResourceType.PROMPT].pop(key)
132+
for key in list(self.resources_mapping.keys()):
133+
if self.capabilities_to_server_id[ResourceType.RESOURCE].get(key) == server_id:
134+
self.resources_mapping.pop(key)
135+
self.capabilities_to_server_id[ResourceType.RESOURCE].pop(key)
136+
for key in list(self.resources_templates_mapping.keys()):
137+
if self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE].get(key) == server_id:
138+
self.resources_templates_mapping.pop(key)
139+
self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE].pop(key)
140+
141+
142+
async def update_sessions(self, server_configs: list[ServerConfig]) -> tuple[list[str], list[str]]:
143+
"""
144+
Update client sessions based on the given server configs.
145+
146+
Args:
147+
server_configs: List of server configs to update.
148+
149+
Returns:
150+
Tuple of two lists: (server ids added, server ids removed)
151+
"""
152+
if not server_configs:
153+
return [], []
154+
155+
current_servers = self.get_alive_sessions()
156+
new_servers = [server_config.name for server_config in server_configs]
157+
158+
server_configs_to_add = [
159+
server_config for server_config in server_configs if server_config.name not in current_servers
160+
]
161+
server_ids_to_remove = [server_id for server_id in current_servers if server_id not in new_servers]
162+
163+
if server_configs_to_add:
164+
for server_config in server_configs_to_add:
165+
await self.add_session(server_config.name, server_config)
166+
167+
if server_ids_to_remove:
168+
for server_id in server_ids_to_remove:
169+
await self.remove_session(server_id)
170+
logger.info(f"Server {server_id} removed successfully")
171+
172+
return [
173+
server_config.name for server_config in server_configs_to_add
174+
], server_ids_to_remove
175+
176+
177+
async def _assemble_tools(self, server_id: str, client_session: ServerConnection):
178+
tools = await client_session.session.list_tools() # type: ignore
179+
for tool in tools.tools:
180+
tool_name = tool.name
181+
if tool_name in self.capabilities_to_server_id[ResourceType.TOOL]:
182+
if self.strict_mode:
183+
raise ValueError(
184+
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
185+
)
186+
else:
187+
# Auto resolve by adding server name prefix
188+
tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}"
189+
190+
self.capabilities_to_server_id[ResourceType.TOOL][tool_name] = server_id
191+
self.tools_mapping[tool_name] = tool
192+
193+
async def _assemble_prompts(self, server_id: str, client_session: ServerConnection):
194+
prompts = await client_session.session.list_prompts() # type: ignore
195+
for prompt in prompts.prompts:
196+
prompt_name = prompt.name
197+
if prompt_name in self.capabilities_to_server_id[ResourceType.PROMPT]:
198+
if self.strict_mode:
199+
raise ValueError(
200+
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
201+
)
202+
else:
203+
# Auto resolve by adding server name prefix
204+
prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}"
205+
206+
self.capabilities_to_server_id[ResourceType.PROMPT][prompt_name] = server_id
207+
self.prompts_mapping[prompt_name] = prompt
208+
209+
async def _assemble_resources(self, server_id: str, client_session: ServerConnection):
210+
resources = await client_session.session.list_resources() # type: ignore
211+
for resource in resources.resources:
212+
resource_uri = resource.uri
213+
if str(resource_uri) in self.capabilities_to_server_id[ResourceType.RESOURCE]:
214+
if self.strict_mode:
215+
raise ValueError(
216+
f"Resource {resource_uri} already exists. Please use unique resource names across all servers."
217+
)
218+
else:
219+
# Auto resolve by adding server name prefix
220+
host = resource_uri.host
221+
resource_uri = AnyUrl.build(
222+
host=f"{server_id}{RESOURCE_SPLITOR}{host}",
223+
scheme=resource_uri.scheme,
224+
path=resource_uri.path,
225+
username=resource_uri.username,
226+
password=resource_uri.password,
227+
port=resource_uri.port,
228+
query=resource_uri.query,
229+
fragment=resource_uri.fragment,
230+
)
231+
232+
self.capabilities_to_server_id[ResourceType.RESOURCE][str(resource_uri)] = server_id
233+
self.resources_mapping[str(resource_uri)] = resource
234+
235+
resource_templates = await client_session.session.list_resource_templates() # type: ignore
236+
for resource_template in resource_templates.resourceTemplates:
237+
resource_template_uri_template = resource_template.uriTemplate
238+
if resource_template_uri_template in self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE]:
239+
if self.strict_mode:
240+
raise ValueError(
241+
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template names across all servers."
242+
)
243+
else:
244+
# Auto resolve by adding server name prefix
245+
resource_template_uri_template = f"{server_id}{RESOURCE_SPLITOR}{resource_template_uri_template}"
246+
247+
self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE][resource_template_uri_template] = server_id
248+
self.resources_templates_mapping[resource_template_uri_template] = resource_template
249+
250+
def get_session(self, server_id: str) -> ServerConnection | None:
251+
# get the client session by server_id
252+
return self.sessions.get(server_id)
253+
254+
def get_capability_server_id(self, resource_type: ResourceType, resource_name: str) -> str | None:
255+
# get the server_id by resource_type and resource_name
256+
return self.capabilities_to_server_id[resource_type].get(resource_name)
257+
258+
def get_resource_schema(self, resource_type: ResourceType, resource_name: str) -> Union[Tool, Prompt, Resource, ResourceTemplate, None]:
259+
if resource_type == ResourceType.TOOL:
260+
return self.tools_mapping.get(resource_name)
261+
elif resource_type == ResourceType.PROMPT:
262+
return self.prompts_mapping.get(resource_name)
263+
elif resource_type == ResourceType.RESOURCE:
264+
return self.resources_mapping.get(resource_name)
265+
elif resource_type == ResourceType.RESOURCE_TEMPLATE:
266+
return self.resources_templates_mapping.get(resource_name)
267+
else:
268+
return None
269+
270+
def get_aggregated_server_capabilities(self) -> ServerCapabilities:
271+
# for initialization of aggregated server
272+
notification_options = NotificationOptions(
273+
prompts_changed=False,
274+
resources_changed=False,
275+
tools_changed=False,
276+
)
277+
278+
has_prompts = any(
279+
server_capabilities.get("prompts") for server_capabilities in self.capabilities_mapping.values()
280+
)
281+
has_resources = any(
282+
server_capabilities.get("resources") for server_capabilities in self.capabilities_mapping.values()
283+
)
284+
has_tools = any(server_capabilities.get("tools") for server_capabilities in self.capabilities_mapping.values())
285+
has_logging = any(
286+
server_capabilities.get("logging") for server_capabilities in self.capabilities_mapping.values()
287+
)
288+
289+
prompts_capability = (
290+
PromptsCapability(listChanged=notification_options.prompts_changed) if has_prompts else None
291+
)
292+
resources_capability = (
293+
ResourcesCapability(subscribe=False, listChanged=notification_options.resources_changed)
294+
if has_resources
295+
else None
296+
)
297+
tools_capability = ToolsCapability(listChanged=notification_options.tools_changed) if has_tools else None
298+
logging_capability = LoggingCapability() if has_logging else None
299+
300+
return ServerCapabilities(
301+
prompts=prompts_capability,
302+
resources=resources_capability,
303+
tools=tools_capability,
304+
logging=logging_capability,
305+
experimental={},
306+
)
307+
308+
309+
async def shutdown(self):
310+
for session in self.sessions.values():
311+
await session.request_for_shutdown()
312+
313+
self.error_log_manager.close_all()

0 commit comments

Comments
 (0)