Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 135 additions & 67 deletions src/mcpm/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
from mcpm.monitor.event import trace_event
from mcpm.profile.profile_config import ProfileConfigManager
from mcpm.schemas.server_config import ServerConfig
from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR
from mcpm.utils.config import (
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_SHARE_ADDRESS,
PROMPT_SPLITOR,
RESOURCE_SPLITOR,
RESOURCE_TEMPLATE_SPLITOR,
TOOL_SPLITOR,
ConfigManager,
)

from .client_connection import ServerConnection
from .transport import RouterSseTransport
Expand All @@ -36,16 +45,42 @@ class MCPRouter:
"""
A router that aggregates multiple MCP servers (SSE/STDIO) and
exposes them as a single SSE server.

Example:
```python
# Initialize with a custom API key
router = MCPRouter(api_key="your-api-key")

# Initialize with custom router configuration
router_config = {
"host": "localhost",
"port": 8080,
"share_address": "custom.share.address:8080"
}
router = MCPRouter(api_key="your-api-key", router_config=router_config)

# Create a global config from the router's configuration
router.create_global_config()
```
"""

def __init__(self, reload_server: bool = False, profile_path: str | None = None, strict: bool = False) -> None:
def __init__(
self,
reload_server: bool = False,
profile_path: str | None = None,
strict: bool = False,
api_key: str | None = None,
router_config: dict | None = None,
) -> None:
"""
Initialize the router.

:param reload_server: Whether to reload the server when the config changes
:param profile_path: Path to the profile file
:param strict: Whether to use strict mode for duplicated tool name.
If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix
:param api_key: Optional API key to use for authentication.
:param router_config: Optional router configuration to use instead of the global config
"""
self.server_sessions: t.Dict[str, ServerConnection] = {}
self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict)
Expand All @@ -60,6 +95,26 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None,
if reload_server:
self.watcher = ConfigWatcher(self.profile_manager.profile_path)
self.strict: bool = strict
self.api_key = api_key
self.router_config = router_config

def create_global_config(self) -> None:
"""
Create a global configuration from the router's configuration.
This is useful if you want to initialize the router with a config
but also want that config to be available globally.
"""
if self.api_key is not None:
config_manager = ConfigManager()
# Save the API key to the global config
config_manager.save_share_config(api_key=self.api_key)

# If router_config is provided, save it to the global config
if self.router_config is not None:
host = self.router_config.get("host", DEFAULT_HOST)
port = self.router_config.get("port", DEFAULT_PORT)
share_address = self.router_config.get("share_address", DEFAULT_SHARE_ADDRESS)
config_manager.save_router_config(host, port, share_address)

def get_unique_servers(self) -> list[ServerConfig]:
profiles = self.profile_manager.list_profiles()
Expand Down Expand Up @@ -129,77 +184,89 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
# Collect server tools, prompts, and resources
if response.capabilities.tools:
tools = await client.session.list_tools() # type: ignore
for tool in tools.tools:
# To make sure tool name is unique across all servers
tool_name = tool.name
if tool_name in self.capabilities_to_server_id["tools"]:
if self.strict:
raise ValueError(
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
)
else:
# Auto resolve by adding server name prefix
tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}"
self.capabilities_to_server_id["tools"][tool_name] = server_id
self.tools_mapping[tool_name] = tool
# Extract ListToolsResult from ServerResult
tools_result = tools.root
if isinstance(tools_result, types.ListToolsResult):
for tool in tools_result.tools:
# To make sure tool name is unique across all servers
tool_name = tool.name
if tool_name in self.capabilities_to_server_id["tools"]:
if self.strict:
raise ValueError(
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
)
else:
# Auto resolve by adding server name prefix
tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}"
self.capabilities_to_server_id["tools"][tool_name] = server_id
self.tools_mapping[tool_name] = tool

if response.capabilities.prompts:
prompts = await client.session.list_prompts() # type: ignore
for prompt in prompts.prompts:
# To make sure prompt name is unique across all servers
prompt_name = prompt.name
if prompt_name in self.capabilities_to_server_id["prompts"]:
if self.strict:
raise ValueError(
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
)
else:
# Auto resolve by adding server name prefix
prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}"
self.prompts_mapping[prompt_name] = prompt
self.capabilities_to_server_id["prompts"][prompt_name] = server_id
# Extract ListPromptsResult from ServerResult
prompts_result = prompts.root
if isinstance(prompts_result, types.ListPromptsResult):
for prompt in prompts_result.prompts:
# To make sure prompt name is unique across all servers
prompt_name = prompt.name
if prompt_name in self.capabilities_to_server_id["prompts"]:
if self.strict:
raise ValueError(
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
)
else:
# Auto resolve by adding server name prefix
prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}"
self.prompts_mapping[prompt_name] = prompt
self.capabilities_to_server_id["prompts"][prompt_name] = server_id

if response.capabilities.resources:
resources = await client.session.list_resources() # type: ignore
for resource in resources.resources:
# To make sure resource URI is unique across all servers
resource_uri = resource.uri
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
if self.strict:
raise ValueError(
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
)
else:
# Auto resolve by adding server name prefix
host = resource_uri.host
resource_uri = AnyUrl.build(
host=f"{server_id}{RESOURCE_SPLITOR}{host}",
scheme=resource_uri.scheme,
path=resource_uri.path,
username=resource_uri.username,
password=resource_uri.password,
port=resource_uri.port,
query=resource_uri.query,
fragment=resource_uri.fragment,
)
self.resources_mapping[str(resource_uri)] = resource
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
# Extract ListResourcesResult from ServerResult
resources_result = resources.root
if isinstance(resources_result, types.ListResourcesResult):
for resource in resources_result.resources:
# To make sure resource URI is unique across all servers
resource_uri = resource.uri
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
if self.strict:
raise ValueError(
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
)
else:
# Auto resolve by adding server name prefix
host = resource_uri.host
resource_uri = AnyUrl.build(
host=f"{server_id}{RESOURCE_SPLITOR}{host}",
scheme=resource_uri.scheme,
path=resource_uri.path,
username=resource_uri.username,
password=resource_uri.password,
port=resource_uri.port,
query=resource_uri.query,
fragment=resource_uri.fragment,
)
self.resources_mapping[str(resource_uri)] = resource
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
resources_templates = await client.session.list_resource_templates() # type: ignore
for resource_template in resources_templates.resourceTemplates:
# To make sure resource template URI is unique across all servers
resource_template_uri_template = resource_template.uriTemplate
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
if self.strict:
raise ValueError(
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
)
else:
# Auto resolve by adding server name prefix
resource_template_uri_template = (
f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}"
)
self.resources_templates_mapping[resource_template_uri_template] = resource_template
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id
# Extract ListResourceTemplatesResult from ServerResult
templates_result = resources_templates.root
if isinstance(templates_result, types.ListResourceTemplatesResult):
for resource_template in templates_result.resourceTemplates:
# To make sure resource template URI is unique across all servers
resource_template_uri_template = resource_template.uriTemplate
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
if self.strict:
raise ValueError(
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
)
else:
# Auto resolve by adding server name prefix
resource_template_uri_template = (
f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}"
)
self.resources_templates_mapping[resource_template_uri_template] = resource_template
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id

async def remove_server(self, server_id: str) -> None:
"""
Expand Down Expand Up @@ -496,7 +563,8 @@ async def get_sse_server_app(
"""
await self.initialize_router()

sse = RouterSseTransport("/messages/")
# Pass the API key to the RouterSseTransport
sse = RouterSseTransport("/messages/", api_key=self.api_key)

async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
Expand Down
17 changes: 16 additions & 1 deletion src/mcpm/router/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def get_key_from_scope(scope: Scope, key_name: str) -> str | None:
class RouterSseTransport(SseServerTransport):
"""A SSE server transport that is used by the router to handle client connections."""

def __init__(self, *args, **kwargs):
def __init__(self, *args, api_key: str | None = None, **kwargs):
self._session_id_to_identifier: dict[UUID, ClientIdentifier] = {}
self.api_key = api_key
super().__init__(*args, **kwargs)

@asynccontextmanager
Expand Down Expand Up @@ -238,6 +239,20 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send):
self._session_id_to_identifier.pop(session_id, None)

def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool:
# If api_key is explicitly set to None, disable API key validation
if self.api_key is None:
logger.debug("API key validation disabled")
return True

# If we have a directly provided API key, verify it matches
if self.api_key is not None:
# If API key doesn't match, return False
if api_key != self.api_key:
logger.warning("Unauthorized API key")
return False
return True

# Otherwise, fall back to the original validation logic
try:
config_manager = ConfigManager()
host = get_key_from_scope(scope, key_name="host") or ""
Expand Down
Loading