Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/mcpm/router/client_connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Optional, cast
from typing import Optional, TextIO, cast

from mcp import ClientSession, InitializeResult, StdioServerParameters, stdio_client
from mcp.client.sse import sse_client
Expand All @@ -10,10 +10,10 @@
logger = logging.getLogger(__name__)


def _stdio_transport_context(server_config: ServerConfig):
def _stdio_transport_context(server_config: ServerConfig, errlog: TextIO):
server_config = cast(STDIOServerConfig, server_config)
server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
return stdio_client(server_params)
return stdio_client(server_params, errlog=errlog)


def _sse_transport_context(server_config: ServerConfig):
Expand All @@ -22,16 +22,19 @@ def _sse_transport_context(server_config: ServerConfig):


class ServerConnection:
def __init__(self, server_config: ServerConfig) -> None:
def __init__(self, server_config: ServerConfig, errlog: TextIO) -> None:
self.session: Optional[ClientSession] = None
self.session_initialized_response: Optional[InitializeResult] = None
self._initialized = False
self.server_config = server_config
self._initialized_event = asyncio.Event()
self._shutdown_event = asyncio.Event()
self._errlog = errlog

self._transport_context_factory = (
_stdio_transport_context if isinstance(server_config, STDIOServerConfig) else _sse_transport_context
lambda config: _stdio_transport_context(config, errlog=self._errlog)
if isinstance(config, STDIOServerConfig)
else _sse_transport_context(config)
)

self._server_task = asyncio.create_task(self._server_lifespan_cycle())
Expand Down
12 changes: 10 additions & 2 deletions src/mcpm/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing as t
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Literal, Optional
from typing import Literal, Optional, TextIO

import uvicorn
from mcp import server, types
Expand All @@ -24,6 +24,7 @@
from mcpm.profile.profile_config import ProfileConfigManager
from mcpm.schemas.server_config import ServerConfig
from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR
from mcpm.utils.errlog_manager import ServerErrorLogManager

from .client_connection import ServerConnection
from .transport import RouterSseTransport
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None,
if reload_server:
self.watcher = ConfigWatcher(self.profile_manager.profile_path)
self.strict: bool = strict
self.error_log_manager = ServerErrorLogManager()

def get_unique_servers(self) -> list[ServerConfig]:
profiles = self.profile_manager.list_profiles()
Expand Down Expand Up @@ -110,11 +112,13 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
raise ValueError(f"Server with ID {server_id} already exists")

# Create client based on connection type
client = ServerConnection(server_config)
errlog: TextIO = self.error_log_manager.open_errlog_file(server_id)
client = ServerConnection(server_config, errlog=errlog)

# Connect to the server
await client.wait_for_initialization()
if not client.healthy():
self.error_log_manager.close_errlog_file(server_id)
raise ValueError(f"Failed to connect to server {server_id}")

response = client.session_initialized_response
Expand Down Expand Up @@ -218,6 +222,7 @@ async def remove_server(self, server_id: str) -> None:
# Remove the server from all collections
del self.server_sessions[server_id]
del self.capabilities_mapping[server_id]
self.error_log_manager.close_errlog_file(server_id)

# Delete registered tools, resources and prompts
for key in list(self.tools_mapping.keys()):
Expand Down Expand Up @@ -574,4 +579,7 @@ async def shutdown(self):
if client.healthy():
await client.request_for_shutdown()

# close all errlog files
self.error_log_manager.close_all()

logger.info("all alive client sessions have been shut down")
34 changes: 34 additions & 0 deletions src/mcpm/utils/errlog_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from pathlib import Path
from typing import Optional, TextIO

from .platform import get_log_directory

DEFAULT_ROOT_STDERR_LOG_DIR = get_log_directory("mcpm") / "errlogs"

class ServerErrorLogManager:
"""
A manager for server error logs.
"""

def __init__(self, root_dir: Optional[Path] = None) -> None:
self.root_log_dir = root_dir if root_dir else DEFAULT_ROOT_STDERR_LOG_DIR
os.makedirs(self.root_log_dir, exist_ok=True)
self._log_files: dict[str, TextIO] = {}

def open_errlog_file(self, server_id: str) -> TextIO:
if server_id not in self._log_files or self._log_files[server_id].closed:
log_file = self.root_log_dir / f"{server_id}.log"
# use line buffering, flush to disk when meeting a newline
self._log_files[server_id] = log_file.open("a", encoding="utf-8", buffering=1)
return self._log_files[server_id]

def close_errlog_file(self, server_id: str) -> None:
if server_id in self._log_files and not self._log_files[server_id].closed:
self._log_files[server_id].flush()
self._log_files[server_id].close()
del self._log_files[server_id]

def close_all(self) -> None:
for server_id in self._log_files:
self.close_errlog_file(server_id)