Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 51 additions & 28 deletions src/mcpm/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.routing import Mount, Route
from starlette.types import AppType
from starlette.types import AppType, Lifespan

from mcpm.monitor.base import AccessEventType
from mcpm.monitor.event import trace_event
from mcpm.profile.profile_config import ProfileConfigManager
from mcpm.profile.profile_config import ProfileConfigManager, DEFAULT_PROFILE_PATH
from mcpm.schemas.server_config import ServerConfig
from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR

Expand All @@ -38,16 +38,17 @@ class MCPRouter:
exposes them as a single SSE server.
"""

def __init__(self, reload_server: bool = False) -> None:
def __init__(self, reload_server: bool = False, profile_path: str | None = DEFAULT_PROFILE_PATH) -> None:
"""Initialize the router."""
self.server_sessions: t.Dict[str, ServerConnection] = {}
self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict)
self.tool_name_to_server_id: t.Dict[str, str] = {}
self.tools_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
self.prompts_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
self.resources_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
self.resources_templates_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
self.aggregated_server = self._create_aggregated_server()
self.profile_manager = ProfileConfigManager()
self.profile_manager = ProfileConfigManager(profile_path=profile_path)
self.watcher: Optional[ConfigWatcher] = None
if reload_server:
self.watcher = ConfigWatcher(self.profile_manager.profile_path)
Expand Down Expand Up @@ -120,10 +121,12 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
# Collect server tools, prompts, and resources
if response.capabilities.tools:
tools = await client.session.list_tools() # type: ignore
# Add tools with namespaced names, preserving existing tools
self.tools_mapping.update(
{f"{server_id}{TOOL_SPLITOR}{tool.name}": tool.model_dump() for tool in tools.tools}
)
for tool in tools.tools:
# To make sure tool name is unique across all servers
if tool.name in self.tool_name_to_server_id:
raise ValueError(f"Tool {tool.name} already exists. Please use unique tool names across all servers.")
self.tool_name_to_server_id[tool.name] = server_id
self.tools_mapping[f"{server_id}{TOOL_SPLITOR}{tool.name}"] = tool.model_dump()

if response.capabilities.prompts:
prompts = await client.session.list_prompts() # type: ignore
Expand Down Expand Up @@ -259,7 +262,7 @@ async def list_tools(req: types.ListToolsRequest) -> types.ServerResult:
for server_tool_id, tool in self.tools_mapping.items():
server_id, _ = parse_namespaced_id(server_tool_id, TOOL_SPLITOR)
if server_id in active_servers:
tool.update({"name": server_tool_id})
# Do not modify the tool name, since it's unique across all servers
tools.append(types.Tool(**tool))

if not tools:
Expand All @@ -271,17 +274,21 @@ async def list_tools(req: types.ListToolsRequest) -> types.ServerResult:
async def call_tool(req: types.CallToolRequest) -> types.ServerResult:
active_servers = get_active_servers(req.params.meta.profile) # type: ignore
logger.info(f"call_tool: {req} with active servers: {active_servers}")

server_id, tool_name = parse_namespaced_id(req.params.name, TOOL_SPLITOR)
if server_id is None or tool_name is None:

tool_name = req.params.name
server_id = self.tool_name_to_server_id.get(tool_name)
if server_id is None:
logger.debug(f"call_tool: {req} with tool_name: {tool_name}. Server ID {server_id} is not found")
return empty_result()
if server_id not in active_servers:
logger.debug(f"call_tool: {req} with tool_name: {tool_name}. Server ID {server_id} is not in active servers")
return empty_result()

try:
result = await self.server_sessions[server_id].session.call_tool(tool_name, req.params.arguments or {})
return types.ServerResult(result)
except Exception as e:
logger.error(f"Error calling tool {tool_name} on server {server_id}: {e}")
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
Expand Down Expand Up @@ -400,24 +407,25 @@ async def _initialize_server_capabilities(self):
capabilities=capabilities,
)

async def start_sse_server(
self, host: str = "localhost", port: int = 8080, allow_origins: t.Optional[t.List[str]] = None
) -> None:
async def get_sse_server_app(
self,
allow_origins: t.Optional[t.List[str]] = None,
include_lifespan: bool = True
) -> AppType:
"""
Start an SSE server that exposes the aggregated MCP server.
Get the SSE server app.

Args:
host: The host to bind to
port: The port to bind to
allow_origins: List of allowed origins for CORS
include_lifespan: Whether to include the router's lifespan manager in the app.

Returns:
An SSE server app
"""
# waiting all servers to be initialized
await self.initialize_router()

# Create SSE transport
sse = RouterSseTransport("/messages/")

# Handle SSE connections
async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
request.scope,
Expand All @@ -430,12 +438,14 @@ async def handle_sse(request: Request) -> None:
self.aggregated_server.initialization_options,
)

@asynccontextmanager
async def lifespan(app: AppType):
yield
await self.shutdown()
lifespan_handler: t.Optional[Lifespan[AppType]] = None
if include_lifespan:
@asynccontextmanager
async def lifespan(app: AppType):
yield
await self.shutdown()
lifespan_handler = lifespan

# Set up middleware for CORS if needed
middleware: t.List[Middleware] = []
if allow_origins is not None:
middleware.append(
Expand All @@ -447,16 +457,29 @@ async def lifespan(app: AppType):
),
)

# Create Starlette app
app = Starlette(
debug=False,
middleware=middleware,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
lifespan=lifespan,
lifespan=lifespan_handler,
)
return app

async def start_sse_server(
self, host: str = "localhost", port: int = 8080, allow_origins: t.Optional[t.List[str]] = None
) -> None:
"""
Start an SSE server that exposes the aggregated MCP server.

Args:
host: The host to bind to
port: The port to bind to
allow_origins: List of allowed origins for CORS
"""
app = await self.get_sse_server_app(allow_origins)

# Configure and start the server
config = uvicorn.Config(
Expand Down
12 changes: 7 additions & 5 deletions src/mcpm/router/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
# maintain session_id to identifier mapping
profile = get_key_from_scope(scope, key_name="profile")
client_id = get_key_from_scope(scope, key_name="client")
if profile is not None:
self._session_id_to_identifier[session_id] = ClientIdentifier(
client_id=client_id or "anonymous", profile=profile
)
logger.debug(f"Session {session_id} mapped to identifier {self._session_id_to_identifier[session_id]}")
logger.debug(f"Profile: {profile}, Client ID: {client_id}")
client_id = client_id or "anonymous"
profile = profile or "default"
self._session_id_to_identifier[session_id] = ClientIdentifier(
client_id=client_id, profile=profile
)
logger.debug(f"Session {session_id} mapped to identifier {self._session_id_to_identifier[session_id]}")

sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0)

Expand Down
Loading