Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .cursor/rules/pytest.mdc
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
description:
globs: *.py
alwaysApply: false
---
always run pytest at the end of a major change
always run ruff lint at then end of any code changes
4 changes: 3 additions & 1 deletion src/mcpm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# Import version from internal module
# Import router module
from . import router
from .router.router import MCPRouter
from .router.router_config import RouterConfig
from .version import __version__

# Define what symbols are exported from this package
__all__ = ["__version__", "router"]
__all__ = ["__version__", "router", "MCPRouter", "RouterConfig"]
82 changes: 62 additions & 20 deletions src/mcpm/commands/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import socket
import subprocess
import sys
import uuid

import click
import psutil
Expand Down Expand Up @@ -41,6 +40,7 @@ def is_process_running(pid):
except Exception:
return False


def is_port_listening(host, port) -> bool:
"""
Check if the specified (host, port) is being listened on.
Expand Down Expand Up @@ -133,9 +133,12 @@ def start_router(verbose):
return

# get router config
config = ConfigManager().get_router_config()
config_manager = ConfigManager()
config = config_manager.get_router_config()
host = config["host"]
port = config["port"]
auth_enabled = config.get("auth_enabled", False)
api_key = config.get("api_key")

# prepare uvicorn command
uvicorn_cmd = [
Expand Down Expand Up @@ -185,9 +188,28 @@ def start_router(verbose):
pid = process.pid
write_pid_file(pid)

# Display router started information
console.print(f"[bold green]MCPRouter started[/] at http://{host}:{port} (PID: {pid})")
console.print(f"Log file: {log_file}")
console.print("Use 'mcpm router off' to stop the router.")

# Display connection instructions
console.print("\n[bold cyan]Connection Information:[/]")

api_key = api_key if auth_enabled else None

# Show URL with or without authentication based on API key availability
if api_key:
# Show authenticated URL
console.print(f"SSE Server URL: [green]http://{host}:{port}/sse?s={api_key}[/]")
console.print("\n[bold cyan]To use a specific profile with authentication:[/]")
console.print(f"[green]http://{host}:{port}/sse?s={api_key}&profile=<profile_name>[/]")
else:
# Show URL without authentication
console.print(f"SSE Server URL: [green]http://{host}:{port}/sse[/]")
console.print("\n[bold cyan]To use a specific profile:[/]")
console.print(f"[green]http://{host}:{port}/sse?profile=<profile_name>[/]")

console.print("\n[yellow]Use 'mcpm router off' to stop the router.[/]")

except Exception as e:
console.print(f"[bold red]Error:[/] Failed to start MCPRouter: {e}")
Expand All @@ -197,17 +219,23 @@ def start_router(verbose):
@click.option("-H", "--host", type=str, help="Host to bind the SSE server to")
@click.option("-p", "--port", type=int, help="Port to bind the SSE server to")
@click.option("-a", "--address", type=str, help="Remote address to share the router")
@click.option(
"--auth/--no-auth", default=True, is_flag=True, help="Enable/disable API key authentication (default: enabled)"
)
@click.option("-s", "--secret", type=str, help="Secret key for authentication")
@click.help_option("-h", "--help")
def set_router_config(host, port, address):
def set_router_config(host, port, address, auth, secret: str | None = None):
"""Set MCPRouter global configuration.

Example:
mcpm router set -H localhost -p 8888
mcpm router set --host 127.0.0.1 --port 9000
mcpm router set --no-auth # disable authentication
mcpm router set --auth # enable authentication
"""
if not host and not port and not address:
if not host and not port and not address and auth is None:
console.print(
"[yellow]No changes were made. Please specify at least one option (--host, --port, or --address)[/]"
"[yellow]No changes were made. Please specify at least one option (--host, --port, --address, --auth/--no-auth)[/]"
)
return

Expand All @@ -219,9 +247,23 @@ def set_router_config(host, port, address):
host = host or current_config["host"]
port = port or current_config["port"]
share_address = address or current_config["share_address"]
api_key = secret

if auth:
# Enable authentication
if api_key is None:
# Generate a new API key if authentication is enabled but no key exists
api_key = secrets.token_urlsafe(32)
console.print("[bold green]API key authentication enabled.[/] Generated new API key.")
else:
console.print("[bold green]API key authentication enabled.[/] Using provided API key.")
else:
# Disable authentication by clearing the API key
api_key = None
console.print("[bold yellow]API key authentication disabled.[/]")

# save config
if config_manager.save_router_config(host, port, share_address):
# save router config
if config_manager.save_router_config(host, port, share_address, api_key=api_key, auth_enabled=auth):
console.print(
f"[bold green]Router configuration updated:[/] host={host}, port={port}, share_address={share_address}"
)
Expand Down Expand Up @@ -329,7 +371,7 @@ def router_status():
if share_config.get("pid"):
if not is_process_running(share_config["pid"]):
console.print("[yellow]Share link is not active, cleaning.[/]")
ConfigManager().save_share_config(share_url=None, share_pid=None, api_key=None)
ConfigManager().save_share_config(share_url=None, share_pid=None)
console.print("[green]Share link cleaned[/]")
else:
console.print(
Expand Down Expand Up @@ -389,17 +431,17 @@ def share(address, profile, http):
tunnel = Tunnel(remote_host, remote_port, config["host"], config["port"], secrets.token_urlsafe(32), http, None)
share_url = tunnel.start_tunnel()
share_pid = tunnel.proc.pid if tunnel.proc else None
# generate random api key
api_key = str(uuid.uuid4())
console.print(f"[bold green]Generated secret for share link: {api_key}[/]")
api_key = config.get("api_key") if config.get("auth_enabled") else None
share_url = share_url + "/sse"
# save share pid and link to config
config_manager.save_share_config(share_url, share_pid, api_key)
config_manager.save_share_config(share_url, share_pid)
profile = profile or "<your_profile>"

# print share link
console.print(f"[bold green]Router is sharing at {share_url}[/]")
console.print(f"[green]Your profile can be accessed with the url {share_url}?s={api_key}&profile={profile}[/]\n")
console.print(
f"[green]Your profile can be accessed with the url {share_url}?profile={profile}{f'&s={api_key}' if api_key else ''}[/]\n"
)
console.print(
"[bold yellow]Be careful about the share link, it will be exposed to the public. Make sure to share to trusted users only.[/]"
)
Expand All @@ -409,17 +451,17 @@ def try_clear_share():
console.print("[bold yellow]Clearing share config...[/]")
config_manager = ConfigManager()
share_config = config_manager.read_share_config()
if share_config["url"]:
if share_config.get("url"):
try:
console.print("[bold yellow]Disabling share link...[/]")
config_manager.save_share_config(share_url=None, share_pid=None, api_key=None)
config_manager.save_share_config(share_url=None, share_pid=None)
console.print("[bold green]Share link disabled[/]")
if share_config["pid"]:
if share_config.get("pid"):
os.kill(share_config["pid"], signal.SIGTERM)
except OSError as e:
if e.errno == 3: # "No such process"
console.print("[yellow]Share process does not exist, cleaning up share config...[/]")
config_manager.save_share_config(share_url=None, share_pid=None, api_key=None)
config_manager.save_share_config(share_url=None, share_pid=None)
else:
console.print(f"[bold red]Error:[/] Failed to stop share link: {e}")

Expand All @@ -431,11 +473,11 @@ def stop_share():
# check if there is a share link already running
config_manager = ConfigManager()
share_config = config_manager.read_share_config()
if not share_config["url"]:
if not share_config.get("url"):
console.print("[yellow]No share link is active.[/]")
return

pid = share_config["pid"]
pid = share_config.get("pid")
if not pid:
console.print("[yellow]No share link is active.[/]")
return
Expand Down
7 changes: 6 additions & 1 deletion src/mcpm/router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mcpm.monitor.event import monitor
from mcpm.router.router import MCPRouter
from mcpm.router.transport import RouterSseTransport
from mcpm.utils.config import ConfigManager
from mcpm.utils.platform import get_log_directory

LOG_DIR = get_log_directory("mcpm")
Expand All @@ -30,8 +31,12 @@
)
logger = logging.getLogger("mcpm.router.daemon")

config = ConfigManager().get_router_config()
api_key = config.get("api_key")
auth_enabled = config.get("auth_enabled", False)

router = MCPRouter(reload_server=True)
sse = RouterSseTransport("/messages/")
sse = RouterSseTransport("/messages/", api_key=api_key if auth_enabled else None)


class NoOpsResponse(Response):
Expand Down
65 changes: 47 additions & 18 deletions src/mcpm/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.routing import Mount, Route
from starlette.types import AppType, Lifespan
from starlette.types import Lifespan

from mcpm.core.schema import ServerConfig
from mcpm.monitor.base import AccessEventType
from mcpm.monitor.event import trace_event
from mcpm.profile.profile_config import ProfileConfigManager
from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR
from mcpm.utils.config import (
PROMPT_SPLITOR,
RESOURCE_SPLITOR,
RESOURCE_TEMPLATE_SPLITOR,
TOOL_SPLITOR,
ConfigManager,
)
from mcpm.utils.errlog_manager import ServerErrorLogManager

from .client_connection import ServerConnection
from .router_config import RouterConfig
from .transport import RouterSseTransport
from .watcher import ConfigWatcher

Expand All @@ -37,16 +44,33 @@ class MCPRouter:
"""
A router that aggregates multiple MCP servers (SSE/STDIO) and
exposes them as a single SSE server.

Example:
```python
# Initialize with a custom API key
router = MCPRouter(router_config=RouterConfig(api_key="your-api-key"))

# Initialize with custom router configuration
router_config = RouterConfig(
api_key="your-api-key",
auth_enabled=True
)
router = MCPRouter(router_config=router_config)
```
"""

def __init__(self, reload_server: bool = False, profile_path: str | None = None, strict: bool = False) -> None:
def __init__(
self,
reload_server: bool = False,
profile_path: str | None = None,
router_config: RouterConfig | None = None,
) -> None:
"""
Initialize the router.

:param reload_server: Whether to reload the server when the config changes
:param profile_path: Path to the profile file
:param strict: Whether to use strict mode for duplicated tool name.
If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix
:param router_config: Optional router configuration to use instead of the global config
"""
self.server_sessions: t.Dict[str, ServerConnection] = {}
self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict)
Expand All @@ -60,7 +84,10 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None,
self.watcher: Optional[ConfigWatcher] = None
if reload_server:
self.watcher = ConfigWatcher(self.profile_manager.profile_path)
self.strict: bool = strict
if router_config is None:
config = ConfigManager().get_router_config()
router_config = RouterConfig(api_key=config.get("api_key"), auth_enabled=config.get("auth_enabled", False))
self.router_config = router_config
self.error_log_manager = ServerErrorLogManager()

def get_unique_servers(self) -> list[ServerConfig]:
Expand Down Expand Up @@ -137,7 +164,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
# To make sure tool name is unique across all servers
tool_name = tool.name
if tool_name in self.capabilities_to_server_id["tools"]:
if self.strict:
if self.router_config.strict:
raise ValueError(
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
)
Expand All @@ -153,7 +180,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
# To make sure prompt name is unique across all servers
prompt_name = prompt.name
if prompt_name in self.capabilities_to_server_id["prompts"]:
if self.strict:
if self.router_config.strict:
raise ValueError(
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
)
Expand All @@ -169,7 +196,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
# To make sure resource URI is unique across all servers
resource_uri = resource.uri
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
if self.strict:
if self.router_config.strict:
raise ValueError(
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
)
Expand All @@ -186,14 +213,14 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
query=resource_uri.query,
fragment=resource_uri.fragment,
)
self.resources_mapping[str(resource_uri)] = resource
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
self.resources_mapping[str(resource_uri)] = resource
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
resources_templates = await client.session.list_resource_templates() # type: ignore
for resource_template in resources_templates.resourceTemplates:
# To make sure resource template URI is unique across all servers
resource_template_uri_template = resource_template.uriTemplate
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
if self.strict:
if self.router_config.strict:
raise ValueError(
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
)
Expand All @@ -202,8 +229,8 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
resource_template_uri_template = (
f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}"
)
self.resources_templates_mapping[resource_template_uri_template] = resource_template
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id
self.resources_templates_mapping[resource_template_uri_template] = resource_template
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id

async def remove_server(self, server_id: str) -> None:
"""
Expand Down Expand Up @@ -488,7 +515,7 @@ async def _initialize_server_capabilities(self):

async def get_sse_server_app(
self, allow_origins: t.Optional[t.List[str]] = None, include_lifespan: bool = True
) -> AppType:
) -> Starlette:
"""
Get the SSE server app.

Expand All @@ -501,7 +528,9 @@ async def get_sse_server_app(
"""
await self.initialize_router()

sse = RouterSseTransport("/messages/")
# Pass the API key to the RouterSseTransport
api_key = None if not self.router_config.auth_enabled else self.router_config.api_key
sse = RouterSseTransport("/messages/", api_key=api_key)

async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
Expand All @@ -515,11 +544,11 @@ async def handle_sse(request: Request) -> None:
self.aggregated_server.initialization_options,
)

lifespan_handler: t.Optional[Lifespan[AppType]] = None
lifespan_handler: t.Optional[Lifespan[Starlette]] = None
if include_lifespan:

@asynccontextmanager
async def lifespan(app: AppType):
async def lifespan(app: Starlette):
yield
await self.shutdown()

Expand Down
Loading