From 923e8f4d6c742dc4ebde10e221e06f31a1bfa7f6 Mon Sep 17 00:00:00 2001 From: calmini Date: Fri, 25 Apr 2025 18:55:25 +0800 Subject: [PATCH] fix: redirect the stderr of mcp servers to local --- src/mcpm/router/client_connection.py | 13 +++++++---- src/mcpm/router/router.py | 12 ++++++++-- src/mcpm/utils/errlog_manager.py | 34 ++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) create mode 100644 src/mcpm/utils/errlog_manager.py diff --git a/src/mcpm/router/client_connection.py b/src/mcpm/router/client_connection.py index a810c9ea..1a4f4ae1 100644 --- a/src/mcpm/router/client_connection.py +++ b/src/mcpm/router/client_connection.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Optional, cast +from typing import Optional, TextIO, cast from mcp import ClientSession, InitializeResult, StdioServerParameters, stdio_client from mcp.client.sse import sse_client @@ -10,10 +10,10 @@ logger = logging.getLogger(__name__) -def _stdio_transport_context(server_config: ServerConfig): +def _stdio_transport_context(server_config: ServerConfig, errlog: TextIO): server_config = cast(STDIOServerConfig, server_config) server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env) - return stdio_client(server_params) + return stdio_client(server_params, errlog=errlog) def _sse_transport_context(server_config: ServerConfig): @@ -22,16 +22,19 @@ def _sse_transport_context(server_config: ServerConfig): class ServerConnection: - def __init__(self, server_config: ServerConfig) -> None: + def __init__(self, server_config: ServerConfig, errlog: TextIO) -> None: self.session: Optional[ClientSession] = None self.session_initialized_response: Optional[InitializeResult] = None self._initialized = False self.server_config = server_config self._initialized_event = asyncio.Event() self._shutdown_event = asyncio.Event() + self._errlog = errlog self._transport_context_factory = ( - _stdio_transport_context if isinstance(server_config, STDIOServerConfig) else _sse_transport_context + lambda config: _stdio_transport_context(config, errlog=self._errlog) + if isinstance(config, STDIOServerConfig) + else _sse_transport_context(config) ) self._server_task = asyncio.create_task(self._server_lifespan_cycle()) diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 8298ef45..64df52d2 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -6,7 +6,7 @@ import typing as t from collections import defaultdict from contextlib import asynccontextmanager -from typing import Literal, Optional +from typing import Literal, Optional, TextIO import uvicorn from mcp import server, types @@ -24,6 +24,7 @@ from mcpm.profile.profile_config import ProfileConfigManager from mcpm.schemas.server_config import ServerConfig from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR +from mcpm.utils.errlog_manager import ServerErrorLogManager from .client_connection import ServerConnection from .transport import RouterSseTransport @@ -60,6 +61,7 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None, if reload_server: self.watcher = ConfigWatcher(self.profile_manager.profile_path) self.strict: bool = strict + self.error_log_manager = ServerErrorLogManager() def get_unique_servers(self) -> list[ServerConfig]: profiles = self.profile_manager.list_profiles() @@ -110,11 +112,13 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: raise ValueError(f"Server with ID {server_id} already exists") # Create client based on connection type - client = ServerConnection(server_config) + errlog: TextIO = self.error_log_manager.open_errlog_file(server_id) + client = ServerConnection(server_config, errlog=errlog) # Connect to the server await client.wait_for_initialization() if not client.healthy(): + self.error_log_manager.close_errlog_file(server_id) raise ValueError(f"Failed to connect to server {server_id}") response = client.session_initialized_response @@ -218,6 +222,7 @@ async def remove_server(self, server_id: str) -> None: # Remove the server from all collections del self.server_sessions[server_id] del self.capabilities_mapping[server_id] + self.error_log_manager.close_errlog_file(server_id) # Delete registered tools, resources and prompts for key in list(self.tools_mapping.keys()): @@ -574,4 +579,7 @@ async def shutdown(self): if client.healthy(): await client.request_for_shutdown() + # close all errlog files + self.error_log_manager.close_all() + logger.info("all alive client sessions have been shut down") diff --git a/src/mcpm/utils/errlog_manager.py b/src/mcpm/utils/errlog_manager.py new file mode 100644 index 00000000..f5d198fb --- /dev/null +++ b/src/mcpm/utils/errlog_manager.py @@ -0,0 +1,34 @@ +import os +from pathlib import Path +from typing import Optional, TextIO + +from .platform import get_log_directory + +DEFAULT_ROOT_STDERR_LOG_DIR = get_log_directory("mcpm") / "errlogs" + +class ServerErrorLogManager: + """ + A manager for server error logs. + """ + + def __init__(self, root_dir: Optional[Path] = None) -> None: + self.root_log_dir = root_dir if root_dir else DEFAULT_ROOT_STDERR_LOG_DIR + os.makedirs(self.root_log_dir, exist_ok=True) + self._log_files: dict[str, TextIO] = {} + + def open_errlog_file(self, server_id: str) -> TextIO: + if server_id not in self._log_files or self._log_files[server_id].closed: + log_file = self.root_log_dir / f"{server_id}.log" + # use line buffering, flush to disk when meeting a newline + self._log_files[server_id] = log_file.open("a", encoding="utf-8", buffering=1) + return self._log_files[server_id] + + def close_errlog_file(self, server_id: str) -> None: + if server_id in self._log_files and not self._log_files[server_id].closed: + self._log_files[server_id].flush() + self._log_files[server_id].close() + del self._log_files[server_id] + + def close_all(self) -> None: + for server_id in self._log_files: + self.close_errlog_file(server_id)