|
| 1 | +import logging |
| 2 | +from collections import defaultdict |
| 3 | +from typing import Any, Dict, TextIO, Union |
| 4 | + |
| 5 | +from mcp.server import NotificationOptions |
| 6 | +from mcp.types import ( |
| 7 | + LoggingCapability, |
| 8 | + Prompt, |
| 9 | + PromptsCapability, |
| 10 | + Resource, |
| 11 | + ResourcesCapability, |
| 12 | + ResourceTemplate, |
| 13 | + ServerCapabilities, |
| 14 | + Tool, |
| 15 | + ToolsCapability, |
| 16 | +) |
| 17 | +from pydantic import AnyUrl |
| 18 | + |
| 19 | +from mcpm.core.router.client_connection import ServerConnection |
| 20 | +from mcpm.core.schema import ResourceType, ServerConfig |
| 21 | +from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, TOOL_SPLITOR |
| 22 | +from mcpm.utils.errlog_manager import ServerErrorLogManager |
| 23 | + |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | +class MCPClientSessionManager: |
| 27 | + |
| 28 | + _instance = None |
| 29 | + |
| 30 | + @classmethod |
| 31 | + def get_instance(cls): |
| 32 | + if cls._instance is None: |
| 33 | + cls._instance = cls() |
| 34 | + return cls._instance |
| 35 | + |
| 36 | + def __init__(self, strict_mode: bool = False) -> None: |
| 37 | + """ |
| 38 | + MCPClientSessionManager serves as a central manager for MCP client sessions. |
| 39 | +
|
| 40 | + It manages connections to MCP servers, maintains session health, assembles server capabilities, |
| 41 | + and provides access to tools, prompts, and resources across all registered servers. |
| 42 | +
|
| 43 | + Args: |
| 44 | + strict_mode: If True, raises errors when duplicate capabilities are detected |
| 45 | + across servers. If False, automatically resolves conflicts by |
| 46 | + adding server name prefixes. |
| 47 | + """ |
| 48 | + self.strict_mode = strict_mode |
| 49 | + self.sessions: Dict[str, ServerConnection] = {} |
| 50 | + |
| 51 | + self.capabilities_mapping: Dict[str, Dict[str, Any]] = defaultdict(dict) |
| 52 | + self.capabilities_to_server_id: Dict[ResourceType, Dict[str, Any]] = defaultdict(dict) |
| 53 | + # real name mapping |
| 54 | + self.tools_mapping: Dict[str, Tool] = {} |
| 55 | + self.prompts_mapping: Dict[str, Prompt] = {} |
| 56 | + self.resources_mapping: Dict[str, Resource] = {} |
| 57 | + self.resources_templates_mapping: Dict[str, ResourceTemplate] = {} |
| 58 | + # error log stream |
| 59 | + self.error_log_manager: ServerErrorLogManager = ServerErrorLogManager() |
| 60 | + |
| 61 | + def get_alive_sessions(self): |
| 62 | + # returns the server_id list of all alive sessions |
| 63 | + return [server_id for server_id, session in self.sessions.items() if session.healthy()] |
| 64 | + |
| 65 | + async def _add_session_impl(self, server_id: str, server_config: ServerConfig): |
| 66 | + # initialize session |
| 67 | + if server_id in self.sessions: |
| 68 | + logger.warning(f"Server with ID {server_id} already exists") |
| 69 | + return |
| 70 | + |
| 71 | + errlog: TextIO = self.error_log_manager.open_errlog_file(server_id) |
| 72 | + client_session = ServerConnection(server_config, errlog=errlog) |
| 73 | + await client_session.wait_for_initialization() |
| 74 | + |
| 75 | + if not client_session.healthy(): |
| 76 | + logger.warning(f"Failed to connect to server {server_id}") |
| 77 | + self.error_log_manager.close_errlog_file(server_id) |
| 78 | + return |
| 79 | + |
| 80 | + self.sessions[server_id] = client_session |
| 81 | + |
| 82 | + # update capabilities |
| 83 | + initialized_response = client_session.session_initialized_response |
| 84 | + self.capabilities_mapping[server_id] = initialized_response.capabilities.model_dump() # type: ignore |
| 85 | + |
| 86 | + # update tools |
| 87 | + if initialized_response.capabilities.tools: # type: ignore |
| 88 | + await self._assemble_tools(server_id, client_session) |
| 89 | + # update prompts |
| 90 | + if initialized_response.capabilities.prompts: # type: ignore |
| 91 | + await self._assemble_prompts(server_id, client_session) |
| 92 | + # update resources |
| 93 | + if initialized_response.capabilities.resources: # type: ignore |
| 94 | + await self._assemble_resources(server_id, client_session) |
| 95 | + |
| 96 | + async def add_session(self, server_id: str, server_config: ServerConfig) -> bool: |
| 97 | + try: |
| 98 | + # add log to see whether client session initialization is blocked in this step |
| 99 | + logger.info(f"Ready to add server: {server_config.name}") |
| 100 | + await self._add_session_impl(server_id, server_config) |
| 101 | + logger.info(f"Server {server_config.name} added successfully") |
| 102 | + return True |
| 103 | + except Exception as e: |
| 104 | + # if went wrong, skip the update |
| 105 | + logger.error(f"Failed to add server {server_config.name}: {e}") |
| 106 | + # if error log is opened, close it |
| 107 | + self.error_log_manager.close_errlog_file(server_id) |
| 108 | + |
| 109 | + return False |
| 110 | + |
| 111 | + async def remove_session(self, server_id: str): |
| 112 | + if server_id not in self.sessions: |
| 113 | + logger.warning(f"Server with ID {server_id} does not exist") |
| 114 | + return |
| 115 | + |
| 116 | + client_session = self.sessions.pop(server_id) |
| 117 | + |
| 118 | + await client_session.request_for_shutdown() |
| 119 | + |
| 120 | + self.capabilities_mapping.pop(server_id) |
| 121 | + # close error log |
| 122 | + self.error_log_manager.close_errlog_file(server_id) |
| 123 | + # remove all mapping reference |
| 124 | + for key in list(self.tools_mapping.keys()): |
| 125 | + if self.capabilities_to_server_id[ResourceType.TOOL].get(key) == server_id: |
| 126 | + self.tools_mapping.pop(key) |
| 127 | + self.capabilities_to_server_id[ResourceType.TOOL].pop(key) |
| 128 | + for key in list(self.prompts_mapping.keys()): |
| 129 | + if self.capabilities_to_server_id[ResourceType.PROMPT].get(key) == server_id: |
| 130 | + self.prompts_mapping.pop(key) |
| 131 | + self.capabilities_to_server_id[ResourceType.PROMPT].pop(key) |
| 132 | + for key in list(self.resources_mapping.keys()): |
| 133 | + if self.capabilities_to_server_id[ResourceType.RESOURCE].get(key) == server_id: |
| 134 | + self.resources_mapping.pop(key) |
| 135 | + self.capabilities_to_server_id[ResourceType.RESOURCE].pop(key) |
| 136 | + for key in list(self.resources_templates_mapping.keys()): |
| 137 | + if self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE].get(key) == server_id: |
| 138 | + self.resources_templates_mapping.pop(key) |
| 139 | + self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE].pop(key) |
| 140 | + |
| 141 | + |
| 142 | + async def update_sessions(self, server_configs: list[ServerConfig]) -> tuple[list[str], list[str]]: |
| 143 | + """ |
| 144 | + Update client sessions based on the given server configs. |
| 145 | +
|
| 146 | + Args: |
| 147 | + server_configs: List of server configs to update. |
| 148 | +
|
| 149 | + Returns: |
| 150 | + Tuple of two lists: (server ids added, server ids removed) |
| 151 | + """ |
| 152 | + if not server_configs: |
| 153 | + return [], [] |
| 154 | + |
| 155 | + current_servers = self.get_alive_sessions() |
| 156 | + new_servers = [server_config.name for server_config in server_configs] |
| 157 | + |
| 158 | + server_configs_to_add = [ |
| 159 | + server_config for server_config in server_configs if server_config.name not in current_servers |
| 160 | + ] |
| 161 | + server_ids_to_remove = [server_id for server_id in current_servers if server_id not in new_servers] |
| 162 | + |
| 163 | + if server_configs_to_add: |
| 164 | + for server_config in server_configs_to_add: |
| 165 | + await self.add_session(server_config.name, server_config) |
| 166 | + |
| 167 | + if server_ids_to_remove: |
| 168 | + for server_id in server_ids_to_remove: |
| 169 | + await self.remove_session(server_id) |
| 170 | + logger.info(f"Server {server_id} removed successfully") |
| 171 | + |
| 172 | + return [ |
| 173 | + server_config.name for server_config in server_configs_to_add |
| 174 | + ], server_ids_to_remove |
| 175 | + |
| 176 | + |
| 177 | + async def _assemble_tools(self, server_id: str, client_session: ServerConnection): |
| 178 | + tools = await client_session.session.list_tools() # type: ignore |
| 179 | + for tool in tools.tools: |
| 180 | + tool_name = tool.name |
| 181 | + if tool_name in self.capabilities_to_server_id[ResourceType.TOOL]: |
| 182 | + if self.strict_mode: |
| 183 | + raise ValueError( |
| 184 | + f"Tool {tool_name} already exists. Please use unique tool names across all servers." |
| 185 | + ) |
| 186 | + else: |
| 187 | + # Auto resolve by adding server name prefix |
| 188 | + tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}" |
| 189 | + |
| 190 | + self.capabilities_to_server_id[ResourceType.TOOL][tool_name] = server_id |
| 191 | + self.tools_mapping[tool_name] = tool |
| 192 | + |
| 193 | + async def _assemble_prompts(self, server_id: str, client_session: ServerConnection): |
| 194 | + prompts = await client_session.session.list_prompts() # type: ignore |
| 195 | + for prompt in prompts.prompts: |
| 196 | + prompt_name = prompt.name |
| 197 | + if prompt_name in self.capabilities_to_server_id[ResourceType.PROMPT]: |
| 198 | + if self.strict_mode: |
| 199 | + raise ValueError( |
| 200 | + f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." |
| 201 | + ) |
| 202 | + else: |
| 203 | + # Auto resolve by adding server name prefix |
| 204 | + prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}" |
| 205 | + |
| 206 | + self.capabilities_to_server_id[ResourceType.PROMPT][prompt_name] = server_id |
| 207 | + self.prompts_mapping[prompt_name] = prompt |
| 208 | + |
| 209 | + async def _assemble_resources(self, server_id: str, client_session: ServerConnection): |
| 210 | + resources = await client_session.session.list_resources() # type: ignore |
| 211 | + for resource in resources.resources: |
| 212 | + resource_uri = resource.uri |
| 213 | + if str(resource_uri) in self.capabilities_to_server_id[ResourceType.RESOURCE]: |
| 214 | + if self.strict_mode: |
| 215 | + raise ValueError( |
| 216 | + f"Resource {resource_uri} already exists. Please use unique resource names across all servers." |
| 217 | + ) |
| 218 | + else: |
| 219 | + # Auto resolve by adding server name prefix |
| 220 | + host = resource_uri.host |
| 221 | + resource_uri = AnyUrl.build( |
| 222 | + host=f"{server_id}{RESOURCE_SPLITOR}{host}", |
| 223 | + scheme=resource_uri.scheme, |
| 224 | + path=resource_uri.path, |
| 225 | + username=resource_uri.username, |
| 226 | + password=resource_uri.password, |
| 227 | + port=resource_uri.port, |
| 228 | + query=resource_uri.query, |
| 229 | + fragment=resource_uri.fragment, |
| 230 | + ) |
| 231 | + |
| 232 | + self.capabilities_to_server_id[ResourceType.RESOURCE][str(resource_uri)] = server_id |
| 233 | + self.resources_mapping[str(resource_uri)] = resource |
| 234 | + |
| 235 | + resource_templates = await client_session.session.list_resource_templates() # type: ignore |
| 236 | + for resource_template in resource_templates.resourceTemplates: |
| 237 | + resource_template_uri_template = resource_template.uriTemplate |
| 238 | + if resource_template_uri_template in self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE]: |
| 239 | + if self.strict_mode: |
| 240 | + raise ValueError( |
| 241 | + f"Resource template {resource_template_uri_template} already exists. Please use unique resource template names across all servers." |
| 242 | + ) |
| 243 | + else: |
| 244 | + # Auto resolve by adding server name prefix |
| 245 | + resource_template_uri_template = f"{server_id}{RESOURCE_SPLITOR}{resource_template_uri_template}" |
| 246 | + |
| 247 | + self.capabilities_to_server_id[ResourceType.RESOURCE_TEMPLATE][resource_template_uri_template] = server_id |
| 248 | + self.resources_templates_mapping[resource_template_uri_template] = resource_template |
| 249 | + |
| 250 | + def get_session(self, server_id: str) -> ServerConnection | None: |
| 251 | + # get the client session by server_id |
| 252 | + return self.sessions.get(server_id) |
| 253 | + |
| 254 | + def get_capability_server_id(self, resource_type: ResourceType, resource_name: str) -> str | None: |
| 255 | + # get the server_id by resource_type and resource_name |
| 256 | + return self.capabilities_to_server_id[resource_type].get(resource_name) |
| 257 | + |
| 258 | + def get_resource_schema(self, resource_type: ResourceType, resource_name: str) -> Union[Tool, Prompt, Resource, ResourceTemplate, None]: |
| 259 | + if resource_type == ResourceType.TOOL: |
| 260 | + return self.tools_mapping.get(resource_name) |
| 261 | + elif resource_type == ResourceType.PROMPT: |
| 262 | + return self.prompts_mapping.get(resource_name) |
| 263 | + elif resource_type == ResourceType.RESOURCE: |
| 264 | + return self.resources_mapping.get(resource_name) |
| 265 | + elif resource_type == ResourceType.RESOURCE_TEMPLATE: |
| 266 | + return self.resources_templates_mapping.get(resource_name) |
| 267 | + else: |
| 268 | + return None |
| 269 | + |
| 270 | + def get_aggregated_server_capabilities(self) -> ServerCapabilities: |
| 271 | + # for initialization of aggregated server |
| 272 | + notification_options = NotificationOptions( |
| 273 | + prompts_changed=False, |
| 274 | + resources_changed=False, |
| 275 | + tools_changed=False, |
| 276 | + ) |
| 277 | + |
| 278 | + has_prompts = any( |
| 279 | + server_capabilities.get("prompts") for server_capabilities in self.capabilities_mapping.values() |
| 280 | + ) |
| 281 | + has_resources = any( |
| 282 | + server_capabilities.get("resources") for server_capabilities in self.capabilities_mapping.values() |
| 283 | + ) |
| 284 | + has_tools = any(server_capabilities.get("tools") for server_capabilities in self.capabilities_mapping.values()) |
| 285 | + has_logging = any( |
| 286 | + server_capabilities.get("logging") for server_capabilities in self.capabilities_mapping.values() |
| 287 | + ) |
| 288 | + |
| 289 | + prompts_capability = ( |
| 290 | + PromptsCapability(listChanged=notification_options.prompts_changed) if has_prompts else None |
| 291 | + ) |
| 292 | + resources_capability = ( |
| 293 | + ResourcesCapability(subscribe=False, listChanged=notification_options.resources_changed) |
| 294 | + if has_resources |
| 295 | + else None |
| 296 | + ) |
| 297 | + tools_capability = ToolsCapability(listChanged=notification_options.tools_changed) if has_tools else None |
| 298 | + logging_capability = LoggingCapability() if has_logging else None |
| 299 | + |
| 300 | + return ServerCapabilities( |
| 301 | + prompts=prompts_capability, |
| 302 | + resources=resources_capability, |
| 303 | + tools=tools_capability, |
| 304 | + logging=logging_capability, |
| 305 | + experimental={}, |
| 306 | + ) |
| 307 | + |
| 308 | + |
| 309 | + async def shutdown(self): |
| 310 | + for session in self.sessions.values(): |
| 311 | + await session.request_for_shutdown() |
| 312 | + |
| 313 | + self.error_log_manager.close_all() |
0 commit comments