diff --git a/.cursor/rules/pytest.mdc b/.cursor/rules/pytest.mdc new file mode 100644 index 00000000..f56a7985 --- /dev/null +++ b/.cursor/rules/pytest.mdc @@ -0,0 +1,7 @@ +--- +description: +globs: *.py +alwaysApply: false +--- +always run pytest at the end of a major change +always run ruff lint at then end of any code changes \ No newline at end of file diff --git a/src/mcpm/__init__.py b/src/mcpm/__init__.py index 5d77ba78..164c1afe 100644 --- a/src/mcpm/__init__.py +++ b/src/mcpm/__init__.py @@ -5,7 +5,9 @@ # Import version from internal module # Import router module from . import router +from .router.router import MCPRouter +from .router.router_config import RouterConfig from .version import __version__ # Define what symbols are exported from this package -__all__ = ["__version__", "router"] +__all__ = ["__version__", "router", "MCPRouter", "RouterConfig"] diff --git a/src/mcpm/commands/router.py b/src/mcpm/commands/router.py index b6a8932c..351c25d0 100644 --- a/src/mcpm/commands/router.py +++ b/src/mcpm/commands/router.py @@ -9,7 +9,6 @@ import socket import subprocess import sys -import uuid import click import psutil @@ -41,6 +40,7 @@ def is_process_running(pid): except Exception: return False + def is_port_listening(host, port) -> bool: """ Check if the specified (host, port) is being listened on. @@ -133,9 +133,12 @@ def start_router(verbose): return # get router config - config = ConfigManager().get_router_config() + config_manager = ConfigManager() + config = config_manager.get_router_config() host = config["host"] port = config["port"] + auth_enabled = config.get("auth_enabled", False) + api_key = config.get("api_key") # prepare uvicorn command uvicorn_cmd = [ @@ -185,9 +188,28 @@ def start_router(verbose): pid = process.pid write_pid_file(pid) + # Display router started information console.print(f"[bold green]MCPRouter started[/] at http://{host}:{port} (PID: {pid})") console.print(f"Log file: {log_file}") - console.print("Use 'mcpm router off' to stop the router.") + + # Display connection instructions + console.print("\n[bold cyan]Connection Information:[/]") + + api_key = api_key if auth_enabled else None + + # Show URL with or without authentication based on API key availability + if api_key: + # Show authenticated URL + console.print(f"SSE Server URL: [green]http://{host}:{port}/sse?s={api_key}[/]") + console.print("\n[bold cyan]To use a specific profile with authentication:[/]") + console.print(f"[green]http://{host}:{port}/sse?s={api_key}&profile=[/]") + else: + # Show URL without authentication + console.print(f"SSE Server URL: [green]http://{host}:{port}/sse[/]") + console.print("\n[bold cyan]To use a specific profile:[/]") + console.print(f"[green]http://{host}:{port}/sse?profile=[/]") + + console.print("\n[yellow]Use 'mcpm router off' to stop the router.[/]") except Exception as e: console.print(f"[bold red]Error:[/] Failed to start MCPRouter: {e}") @@ -197,17 +219,23 @@ def start_router(verbose): @click.option("-H", "--host", type=str, help="Host to bind the SSE server to") @click.option("-p", "--port", type=int, help="Port to bind the SSE server to") @click.option("-a", "--address", type=str, help="Remote address to share the router") +@click.option( + "--auth/--no-auth", default=True, is_flag=True, help="Enable/disable API key authentication (default: enabled)" +) +@click.option("-s", "--secret", type=str, help="Secret key for authentication") @click.help_option("-h", "--help") -def set_router_config(host, port, address): +def set_router_config(host, port, address, auth, secret: str | None = None): """Set MCPRouter global configuration. Example: mcpm router set -H localhost -p 8888 mcpm router set --host 127.0.0.1 --port 9000 + mcpm router set --no-auth # disable authentication + mcpm router set --auth # enable authentication """ - if not host and not port and not address: + if not host and not port and not address and auth is None: console.print( - "[yellow]No changes were made. Please specify at least one option (--host, --port, or --address)[/]" + "[yellow]No changes were made. Please specify at least one option (--host, --port, --address, --auth/--no-auth)[/]" ) return @@ -219,9 +247,23 @@ def set_router_config(host, port, address): host = host or current_config["host"] port = port or current_config["port"] share_address = address or current_config["share_address"] + api_key = secret + + if auth: + # Enable authentication + if api_key is None: + # Generate a new API key if authentication is enabled but no key exists + api_key = secrets.token_urlsafe(32) + console.print("[bold green]API key authentication enabled.[/] Generated new API key.") + else: + console.print("[bold green]API key authentication enabled.[/] Using provided API key.") + else: + # Disable authentication by clearing the API key + api_key = None + console.print("[bold yellow]API key authentication disabled.[/]") - # save config - if config_manager.save_router_config(host, port, share_address): + # save router config + if config_manager.save_router_config(host, port, share_address, api_key=api_key, auth_enabled=auth): console.print( f"[bold green]Router configuration updated:[/] host={host}, port={port}, share_address={share_address}" ) @@ -329,7 +371,7 @@ def router_status(): if share_config.get("pid"): if not is_process_running(share_config["pid"]): console.print("[yellow]Share link is not active, cleaning.[/]") - ConfigManager().save_share_config(share_url=None, share_pid=None, api_key=None) + ConfigManager().save_share_config(share_url=None, share_pid=None) console.print("[green]Share link cleaned[/]") else: console.print( @@ -389,17 +431,17 @@ def share(address, profile, http): tunnel = Tunnel(remote_host, remote_port, config["host"], config["port"], secrets.token_urlsafe(32), http, None) share_url = tunnel.start_tunnel() share_pid = tunnel.proc.pid if tunnel.proc else None - # generate random api key - api_key = str(uuid.uuid4()) - console.print(f"[bold green]Generated secret for share link: {api_key}[/]") + api_key = config.get("api_key") if config.get("auth_enabled") else None share_url = share_url + "/sse" # save share pid and link to config - config_manager.save_share_config(share_url, share_pid, api_key) + config_manager.save_share_config(share_url, share_pid) profile = profile or "" # print share link console.print(f"[bold green]Router is sharing at {share_url}[/]") - console.print(f"[green]Your profile can be accessed with the url {share_url}?s={api_key}&profile={profile}[/]\n") + console.print( + f"[green]Your profile can be accessed with the url {share_url}?profile={profile}{f'&s={api_key}' if api_key else ''}[/]\n" + ) console.print( "[bold yellow]Be careful about the share link, it will be exposed to the public. Make sure to share to trusted users only.[/]" ) @@ -409,17 +451,17 @@ def try_clear_share(): console.print("[bold yellow]Clearing share config...[/]") config_manager = ConfigManager() share_config = config_manager.read_share_config() - if share_config["url"]: + if share_config.get("url"): try: console.print("[bold yellow]Disabling share link...[/]") - config_manager.save_share_config(share_url=None, share_pid=None, api_key=None) + config_manager.save_share_config(share_url=None, share_pid=None) console.print("[bold green]Share link disabled[/]") - if share_config["pid"]: + if share_config.get("pid"): os.kill(share_config["pid"], signal.SIGTERM) except OSError as e: if e.errno == 3: # "No such process" console.print("[yellow]Share process does not exist, cleaning up share config...[/]") - config_manager.save_share_config(share_url=None, share_pid=None, api_key=None) + config_manager.save_share_config(share_url=None, share_pid=None) else: console.print(f"[bold red]Error:[/] Failed to stop share link: {e}") @@ -431,11 +473,11 @@ def stop_share(): # check if there is a share link already running config_manager = ConfigManager() share_config = config_manager.read_share_config() - if not share_config["url"]: + if not share_config.get("url"): console.print("[yellow]No share link is active.[/]") return - pid = share_config["pid"] + pid = share_config.get("pid") if not pid: console.print("[yellow]No share link is active.[/]") return diff --git a/src/mcpm/router/app.py b/src/mcpm/router/app.py index 27862a8e..42830018 100644 --- a/src/mcpm/router/app.py +++ b/src/mcpm/router/app.py @@ -16,6 +16,7 @@ from mcpm.monitor.event import monitor from mcpm.router.router import MCPRouter from mcpm.router.transport import RouterSseTransport +from mcpm.utils.config import ConfigManager from mcpm.utils.platform import get_log_directory LOG_DIR = get_log_directory("mcpm") @@ -30,8 +31,12 @@ ) logger = logging.getLogger("mcpm.router.daemon") +config = ConfigManager().get_router_config() +api_key = config.get("api_key") +auth_enabled = config.get("auth_enabled", False) + router = MCPRouter(reload_server=True) -sse = RouterSseTransport("/messages/") +sse = RouterSseTransport("/messages/", api_key=api_key if auth_enabled else None) class NoOpsResponse(Response): diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 2d1085d1..908a167f 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -17,16 +17,23 @@ from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.routing import Mount, Route -from starlette.types import AppType, Lifespan +from starlette.types import Lifespan from mcpm.core.schema import ServerConfig from mcpm.monitor.base import AccessEventType from mcpm.monitor.event import trace_event from mcpm.profile.profile_config import ProfileConfigManager -from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR +from mcpm.utils.config import ( + PROMPT_SPLITOR, + RESOURCE_SPLITOR, + RESOURCE_TEMPLATE_SPLITOR, + TOOL_SPLITOR, + ConfigManager, +) from mcpm.utils.errlog_manager import ServerErrorLogManager from .client_connection import ServerConnection +from .router_config import RouterConfig from .transport import RouterSseTransport from .watcher import ConfigWatcher @@ -37,16 +44,33 @@ class MCPRouter: """ A router that aggregates multiple MCP servers (SSE/STDIO) and exposes them as a single SSE server. + + Example: + ```python + # Initialize with a custom API key + router = MCPRouter(router_config=RouterConfig(api_key="your-api-key")) + + # Initialize with custom router configuration + router_config = RouterConfig( + api_key="your-api-key", + auth_enabled=True + ) + router = MCPRouter(router_config=router_config) + ``` """ - def __init__(self, reload_server: bool = False, profile_path: str | None = None, strict: bool = False) -> None: + def __init__( + self, + reload_server: bool = False, + profile_path: str | None = None, + router_config: RouterConfig | None = None, + ) -> None: """ Initialize the router. :param reload_server: Whether to reload the server when the config changes :param profile_path: Path to the profile file - :param strict: Whether to use strict mode for duplicated tool name. - If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix + :param router_config: Optional router configuration to use instead of the global config """ self.server_sessions: t.Dict[str, ServerConnection] = {} self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict) @@ -60,7 +84,10 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None, self.watcher: Optional[ConfigWatcher] = None if reload_server: self.watcher = ConfigWatcher(self.profile_manager.profile_path) - self.strict: bool = strict + if router_config is None: + config = ConfigManager().get_router_config() + router_config = RouterConfig(api_key=config.get("api_key"), auth_enabled=config.get("auth_enabled", False)) + self.router_config = router_config self.error_log_manager = ServerErrorLogManager() def get_unique_servers(self) -> list[ServerConfig]: @@ -137,7 +164,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure tool name is unique across all servers tool_name = tool.name if tool_name in self.capabilities_to_server_id["tools"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Tool {tool_name} already exists. Please use unique tool names across all servers." ) @@ -153,7 +180,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure prompt name is unique across all servers prompt_name = prompt.name if prompt_name in self.capabilities_to_server_id["prompts"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." ) @@ -169,7 +196,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure resource URI is unique across all servers resource_uri = resource.uri if str(resource_uri) in self.capabilities_to_server_id["resources"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." ) @@ -186,14 +213,14 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: query=resource_uri.query, fragment=resource_uri.fragment, ) - self.resources_mapping[str(resource_uri)] = resource - self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id + self.resources_mapping[str(resource_uri)] = resource + self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id resources_templates = await client.session.list_resource_templates() # type: ignore for resource_template in resources_templates.resourceTemplates: # To make sure resource template URI is unique across all servers resource_template_uri_template = resource_template.uriTemplate if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." ) @@ -202,8 +229,8 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: resource_template_uri_template = ( f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}" ) - self.resources_templates_mapping[resource_template_uri_template] = resource_template - self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id + self.resources_templates_mapping[resource_template_uri_template] = resource_template + self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id async def remove_server(self, server_id: str) -> None: """ @@ -488,7 +515,7 @@ async def _initialize_server_capabilities(self): async def get_sse_server_app( self, allow_origins: t.Optional[t.List[str]] = None, include_lifespan: bool = True - ) -> AppType: + ) -> Starlette: """ Get the SSE server app. @@ -501,7 +528,9 @@ async def get_sse_server_app( """ await self.initialize_router() - sse = RouterSseTransport("/messages/") + # Pass the API key to the RouterSseTransport + api_key = None if not self.router_config.auth_enabled else self.router_config.api_key + sse = RouterSseTransport("/messages/", api_key=api_key) async def handle_sse(request: Request) -> None: async with sse.connect_sse( @@ -515,11 +544,11 @@ async def handle_sse(request: Request) -> None: self.aggregated_server.initialization_options, ) - lifespan_handler: t.Optional[Lifespan[AppType]] = None + lifespan_handler: t.Optional[Lifespan[Starlette]] = None if include_lifespan: @asynccontextmanager - async def lifespan(app: AppType): + async def lifespan(app: Starlette): yield await self.shutdown() diff --git a/src/mcpm/router/router_config.py b/src/mcpm/router/router_config.py new file mode 100644 index 00000000..93980f70 --- /dev/null +++ b/src/mcpm/router/router_config.py @@ -0,0 +1,20 @@ +from typing import Optional + +from pydantic import BaseModel, field_validator + + +class RouterConfig(BaseModel): + """ + Router configuration model for MCPRouter + """ + + strict: bool = False + api_key: Optional[str] = None + auth_enabled: bool = False + + @field_validator("api_key", mode="after") + def check_api_key(cls, v, info): + # info is ValidationInfo in pydantic v2; info.data is the dict of parsed values + if info.data.get("auth_enabled") and v is None: + raise ValueError("api_key must be provided when auth_enabled is True") + return v diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 730e7b24..f8a277f7 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -66,8 +66,9 @@ def get_key_from_scope(scope: Scope, key_name: str) -> str | None: class RouterSseTransport(SseServerTransport): """A SSE server transport that is used by the router to handle client connections.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, api_key: str | None = None, **kwargs): self._session_id_to_identifier: dict[UUID, ClientIdentifier] = {} + self.api_key = api_key super().__init__(*args, **kwargs) @asynccontextmanager @@ -238,21 +239,29 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send): self._session_id_to_identifier.pop(session_id, None) def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: + # If api_key is explicitly set to None, disable API key validation + if self.api_key is None: + logger.debug("API key validation disabled") + return True + + # If we have a directly provided API key, verify it matches + if api_key == self.api_key: + return True + + # At this point, self.api_key is not None but doesn't match the provided api_key + # Let's check if this is a share URL that needs special validation try: config_manager = ConfigManager() host = get_key_from_scope(scope, key_name="host") or "" if not host.startswith("http"): host = f"http://{host}" - share_config = config_manager.read_share_config() router_config = config_manager.get_router_config() host_name = urlsplit(host).hostname - share_host_name = urlsplit(share_config["url"]).hostname - if share_config["url"] and (host_name == share_host_name or host_name != router_config["host"]): - share_api_key = share_config["api_key"] - if api_key != share_api_key: - logger.warning("Unauthorized API key") + if host_name != router_config["host"]: + if api_key != self.api_key: return False except Exception as e: logger.error(f"Failed to validate API key: {e}") return False + return True diff --git a/src/mcpm/utils/config.py b/src/mcpm/utils/config.py index 02d43976..9040c1ca 100644 --- a/src/mcpm/utils/config.py +++ b/src/mcpm/utils/config.py @@ -127,7 +127,7 @@ def get_router_config(self): return router_config - def save_router_config(self, host, port, share_address): + def save_router_config(self, host, port, share_address, api_key: str | None = None, auth_enabled: bool = False): """save router configuration to config file""" router_config = self.get_config().get("router", {}) @@ -135,12 +135,14 @@ def save_router_config(self, host, port, share_address): router_config["host"] = host router_config["port"] = port router_config["share_address"] = share_address + router_config["api_key"] = api_key + router_config["auth_enabled"] = auth_enabled # save config return self.set_config("router", router_config) - def save_share_config(self, share_url: str | None = None, share_pid: int | None = None, api_key: str | None = None): - return self.set_config("share", {"url": share_url, "pid": share_pid, "api_key": api_key}) + def save_share_config(self, share_url: str | None = None, share_pid: int | None = None): + return self.set_config("share", {"url": share_url, "pid": share_pid}) def read_share_config(self) -> Dict[str, Any]: return self.get_config().get("share", {}) diff --git a/tests/test_profile.py b/tests/test_profile.py new file mode 100644 index 00000000..08a8dfb1 --- /dev/null +++ b/tests/test_profile.py @@ -0,0 +1,218 @@ +""" +Tests for the profile module +""" + +import json +import os +import tempfile +from unittest.mock import patch + +import pytest + +from mcpm.profile.profile_config import ProfileConfigManager +from mcpm.schemas.server_config import SSEServerConfig, STDIOServerConfig + + +@pytest.fixture +def temp_profile_file(): + """Create a temporary profile config file for testing""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f: + # Create a basic profile config + config = { + "test_profile": [{"name": "test-server", "type": "sse", "url": "http://localhost:8080/sse"}], + "empty_profile": [], + } + f.write(json.dumps(config).encode("utf-8")) + temp_path = f.name + + yield temp_path + # Clean up + os.unlink(temp_path) + + +@pytest.fixture +def profile_manager(temp_profile_file): + """Create a ProfileConfigManager with a temp config for testing""" + return ProfileConfigManager(profile_path=temp_profile_file) + + +def test_profile_manager_init_default_path(): + """Test that the profile manager initializes with default path""" + with patch("mcpm.profile.profile_config.os.path.exists", return_value=False): + manager = ProfileConfigManager() + assert manager.profile_path == os.path.expanduser("~/.config/mcpm/profiles.json") + assert manager._profiles == {} + + +def test_profile_manager_init_custom_path(temp_profile_file): + """Test that the profile manager initializes with a custom path""" + manager = ProfileConfigManager(profile_path=temp_profile_file) + assert manager.profile_path == temp_profile_file + assert "test_profile" in manager._profiles + assert "empty_profile" in manager._profiles + + +def test_load_profiles_not_exists(): + """Test loading profiles when file doesn't exist""" + with patch("mcpm.profile.profile_config.os.path.exists", return_value=False): + manager = ProfileConfigManager() + profiles = manager._load_profiles() + assert profiles == {} + + +def test_load_profiles(profile_manager): + """Test loading profiles from file""" + profiles = profile_manager._load_profiles() + assert "test_profile" in profiles + assert "empty_profile" in profiles + assert len(profiles["test_profile"]) == 1 + assert len(profiles["empty_profile"]) == 0 + + +def test_new_profile(profile_manager): + """Test creating a new profile""" + # Create new profile + result = profile_manager.new_profile("new_profile") + assert result is True + assert "new_profile" in profile_manager._profiles + assert profile_manager._profiles["new_profile"] == [] + + # Test creating existing profile + result = profile_manager.new_profile("test_profile") + assert result is False + + +def test_get_profile(profile_manager): + """Test getting a profile""" + # Get existing profile + profile = profile_manager.get_profile("test_profile") + assert profile is not None + assert len(profile) == 1 + assert profile[0].name == "test-server" + + # Get non-existent profile + profile = profile_manager.get_profile("non_existent") + assert profile is None + + +def test_get_profile_server(profile_manager): + """Test getting a server from a profile""" + # Get existing server + server = profile_manager.get_profile_server("test_profile", "test-server") + assert server is not None + assert server.name == "test-server" + + # Get non-existent server + server = profile_manager.get_profile_server("test_profile", "non-existent") + assert server is None + + # Get server from non-existent profile + server = profile_manager.get_profile_server("non_existent", "test-server") + assert server is None + + +def test_set_profile_new_server(profile_manager): + """Test setting a new server in a profile""" + new_server = SSEServerConfig(name="new-server", url="http://localhost:8081/sse") + result = profile_manager.set_profile("test_profile", new_server) + assert result is True + + # Verify server was added + servers = profile_manager.get_profile("test_profile") + assert len(servers) == 2 + server_names = [s.name for s in servers] + assert "new-server" in server_names + + +def test_set_profile_update_server(profile_manager): + """Test updating an existing server in a profile""" + updated_server = SSEServerConfig(name="test-server", url="http://localhost:8082/sse") + result = profile_manager.set_profile("test_profile", updated_server) + assert result is True + + # Verify server was updated + server = profile_manager.get_profile_server("test_profile", "test-server") + assert server is not None + assert server.url == "http://localhost:8082/sse" + + +def test_set_profile_new_profile(profile_manager): + """Test setting a server in a new profile""" + new_server = STDIOServerConfig(name="stdio-server", command="test-command", args=["--arg1", "--arg2"]) + result = profile_manager.set_profile("new_profile", new_server) + assert result is True + + # Verify profile and server were created + profile = profile_manager.get_profile("new_profile") + assert profile is not None + assert len(profile) == 1 + assert profile[0].name == "stdio-server" + + +def test_delete_profile(profile_manager): + """Test deleting a profile""" + # Delete existing profile + result = profile_manager.delete_profile("test_profile") + assert result is True + assert "test_profile" not in profile_manager._profiles + + # Delete non-existent profile + result = profile_manager.delete_profile("non_existent") + assert result is False + + +def test_list_profiles(profile_manager): + """Test listing all profiles""" + profiles = profile_manager.list_profiles() + assert "test_profile" in profiles + assert "empty_profile" in profiles + assert len(profiles["test_profile"]) == 1 + assert len(profiles["empty_profile"]) == 0 + + +def test_rename_profile(profile_manager): + """Test renaming a profile""" + # Rename existing profile + result = profile_manager.rename_profile("test_profile", "renamed_profile") + assert result is True + assert "test_profile" not in profile_manager._profiles + assert "renamed_profile" in profile_manager._profiles + + # Rename to existing profile name + result = profile_manager.rename_profile("renamed_profile", "empty_profile") + assert result is False + + # Rename non-existent profile + result = profile_manager.rename_profile("non_existent", "new_name") + assert result is False + + +def test_remove_server(profile_manager): + """Test removing a server from a profile""" + # Remove existing server + result = profile_manager.remove_server("test_profile", "test-server") + assert result is True + + # Verify server was removed + profile = profile_manager.get_profile("test_profile") + assert len(profile) == 0 + + # Remove non-existent server + result = profile_manager.remove_server("test_profile", "non-existent") + assert result is False + + # Remove from non-existent profile + result = profile_manager.remove_server("non_existent", "test-server") + assert result is False + + +def test_reload(profile_manager): + """Test reloading profiles from file""" + # Modify profiles + profile_manager._profiles = {} + assert len(profile_manager._profiles) == 0 + + # Reload + profile_manager.reload() + assert "test_profile" in profile_manager._profiles + assert "empty_profile" in profile_manager._profiles diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 00000000..d4e80f16 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,294 @@ +""" +Tests for the router module +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mcp import InitializeResult +from mcp.types import ListToolsResult, ServerCapabilities, Tool, ToolsCapability + +from mcpm.router.client_connection import ServerConnection +from mcpm.router.router import MCPRouter +from mcpm.router.router_config import RouterConfig +from mcpm.schemas.server_config import SSEServerConfig + + +@pytest.fixture +def mock_server_connection(): + """Create a mock server connection for testing""" + mock_conn = MagicMock(spec=ServerConnection) + mock_conn.healthy.return_value = True + mock_conn.request_for_shutdown = AsyncMock() + + # Create valid ServerCapabilities with ToolsCapability + tools_capability = ToolsCapability(listChanged=False) + capabilities = ServerCapabilities( + prompts=None, resources=None, tools=tools_capability, logging=None, experimental={} + ) + + # Mock session initialized response + mock_conn.session_initialized_response = InitializeResult( + protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "test-server", "version": "1.0.0"} + ) + + # Mock session + mock_session = AsyncMock() + # Create a valid tool with proper inputSchema structure + mock_tool = Tool(name="test-tool", description="A test tool", inputSchema={"type": "object", "properties": {}}) + # Create a ListToolsResult to be returned directly + tools_result = ListToolsResult(tools=[mock_tool]) + mock_session.list_tools = AsyncMock(return_value=tools_result) + # If you have prompts/resources, mock them similarly: + mock_session.list_prompts = AsyncMock(return_value=MagicMock(prompts=[])) + mock_session.list_resources = AsyncMock(return_value=MagicMock(resources=[])) + mock_session.list_resource_templates = AsyncMock(return_value=MagicMock(resourceTemplates=[])) + + mock_conn.session = mock_session + return mock_conn + + +@pytest.mark.asyncio +async def test_router_init(): + """Test initializing the router""" + # Test with default values + router = MCPRouter() + assert router.profile_manager is not None + assert router.watcher is None + assert router.router_config is not None + assert router.router_config.strict is False + + # Test with custom values + config = RouterConfig(api_key="test-api-key", strict=True) + router = MCPRouter( + reload_server=True, + router_config=config, + ) + + assert router.watcher is not None + assert router.router_config == config + assert router.router_config.api_key == "test-api-key" + assert router.router_config.strict is True + + +@pytest.mark.asyncio +async def test_add_server(mock_server_connection): + """Test adding a server to the router""" + router = MCPRouter() + + # Mock get_active_servers to return all server IDs + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + # Patch the _patch_handler_func method to use our mock + with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + server_config = SSEServerConfig(name="test-server", url="http://localhost:8080/sse") + + with patch("mcpm.router.router.ServerConnection", return_value=mock_server_connection): + await router.add_server("test-server", server_config) + + # Verify server was added + assert "test-server" in router.server_sessions + assert router.server_sessions["test-server"] == mock_server_connection + + # Verify capabilities were stored + assert "test-server" in router.capabilities_mapping + + # Verify tool was stored + assert "test-tool" in router.tools_mapping + assert router.capabilities_to_server_id["tools"]["test-tool"] == "test-server" + + # Test adding duplicate server + with pytest.raises(ValueError): + await router.add_server("test-server", server_config) + + +@pytest.mark.asyncio +async def test_add_server_unhealthy(): + """Test adding an unhealthy server""" + router = MCPRouter() + server_config = SSEServerConfig(name="unhealthy-server", url="http://localhost:8080/sse") + + mock_conn = MagicMock(spec=ServerConnection) + mock_conn.healthy.return_value = False + + with patch("mcpm.router.router.ServerConnection", return_value=mock_conn): + with pytest.raises(ValueError, match="Failed to connect to server unhealthy-server"): + await router.add_server("unhealthy-server", server_config) + + +@pytest.mark.asyncio +async def test_remove_server(): + """Test removing a server from the router""" + router = MCPRouter() + + # Setup mock server session with an awaitable request_for_shutdown + mock_session = AsyncMock() + mock_session.close = AsyncMock() + + mock_server = MagicMock(spec=ServerConnection) + mock_server.session = mock_session + mock_server.request_for_shutdown = AsyncMock() + + # Mock server and capabilities + router.server_sessions = {"test-server": mock_server} + router.capabilities_mapping = {"test-server": {"tools": True}} + router.capabilities_to_server_id = {"tools": {"test-tool": "test-server"}} + router.tools_mapping = {"test-tool": MagicMock()} + + # Remove server + await router.remove_server("test-server") + + # Verify server was removed + assert "test-server" not in router.server_sessions + assert "test-server" not in router.capabilities_mapping + assert "test-tool" not in router.capabilities_to_server_id["tools"] + assert "test-tool" not in router.tools_mapping + + # Verify request_for_shutdown was called + mock_server.request_for_shutdown.assert_called_once() + + # Test removing non-existent server + with pytest.raises(ValueError, match="Server with ID non-existent does not exist"): + await router.remove_server("non-existent") + + +@pytest.mark.asyncio +async def test_update_servers(mock_server_connection): + """Test updating servers based on configuration""" + router = MCPRouter() + + # Mock get_active_servers to return all server IDs + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + # Patch the _patch_handler_func method to use our mock + with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + # Setup initial servers with awaitable request_for_shutdown + mock_old_server = MagicMock(spec=ServerConnection) + mock_old_server.session = AsyncMock() + mock_old_server.request_for_shutdown = AsyncMock() + + router.server_sessions = {"old-server": mock_old_server} + # Initialize capabilities_mapping for the old server + router.capabilities_mapping = {"old-server": {"tools": True}} + + # Configure new servers + server_configs = [SSEServerConfig(name="test-server", url="http://localhost:8080/sse")] + + with patch("mcpm.router.router.ServerConnection", return_value=mock_server_connection): + await router.update_servers(server_configs) + + # Verify old server was removed + assert "old-server" not in router.server_sessions + mock_old_server.request_for_shutdown.assert_called_once() + + # Verify new server was added + assert "test-server" in router.server_sessions + + # Test with empty configs - should not change anything + router.server_sessions = {"test-server": mock_server_connection} + await router.update_servers([]) + assert "test-server" in router.server_sessions + + +@pytest.mark.asyncio +async def test_update_servers_error_handling(): + """Test error handling during server updates""" + router = MCPRouter() + + # Setup initial servers with awaitable request_for_shutdown + mock_old_server = MagicMock(spec=ServerConnection) + mock_old_server.session = AsyncMock() + mock_old_server.request_for_shutdown = AsyncMock() + + router.server_sessions = {"old-server": mock_old_server} + # Initialize capabilities_mapping for the old server + router.capabilities_mapping = {"old-server": {"tools": True}} + + # Configure new servers + server_configs = [SSEServerConfig(name="test-server", url="http://localhost:8080/sse")] + + # Mock add_server to raise exception + with patch.object(router, "add_server", side_effect=Exception("Test error")): + # Should not raise exception + await router.update_servers(server_configs) + + # Old server should still be removed + assert "old-server" not in router.server_sessions + mock_old_server.request_for_shutdown.assert_called_once() + + # New server should not be added + assert "test-server" not in router.server_sessions + + +@pytest.mark.asyncio +async def test_router_sse_transport_no_api_key(): + """Test RouterSseTransport with no API key (authentication disabled)""" + + from mcpm.router.transport import RouterSseTransport + + # Create a RouterSseTransport with no API key + transport = RouterSseTransport("/messages/", api_key=None) + + # Create a mock scope + mock_scope = {"type": "http"} + + # Test _validate_api_key method directly + assert transport._validate_api_key(mock_scope, api_key=None) + assert transport._validate_api_key(mock_scope, api_key="any-key") + + # Test with various API key values - all should be allowed + assert transport._validate_api_key(mock_scope, api_key="test-key") + assert transport._validate_api_key(mock_scope, api_key="invalid-key") + assert transport._validate_api_key(mock_scope, api_key="") + + +@pytest.mark.asyncio +async def test_router_sse_transport_with_api_key(): + """Test RouterSseTransport with API key (authentication enabled)""" + + from mcpm.router.transport import RouterSseTransport + + # Create a RouterSseTransport with an API key + transport = RouterSseTransport("/messages/", api_key="correct-api-key") + + # Create a mock scope + mock_scope = {"type": "http"} + + # Test _validate_api_key method directly + # With the correct API key + assert transport._validate_api_key(mock_scope, api_key="correct-api-key") + + # With an incorrect API key + assert not transport._validate_api_key(mock_scope, api_key="wrong-api-key") + + # With no API key + assert not transport._validate_api_key(mock_scope, api_key=None) + + # Test with empty string + assert not transport._validate_api_key(mock_scope, api_key="") + + +@pytest.mark.asyncio +async def test_get_sse_server_app_with_api_key(): + with patch("mcpm.router.router.RouterSseTransport") as mock_transport: + router = MCPRouter(router_config=RouterConfig(auth_enabled=True, api_key="test-api-key")) + await router.get_sse_server_app() + mock_transport.assert_called_once() + call_kwargs = mock_transport.call_args[1] + assert call_kwargs.get("api_key") == "test-api-key" + + +@pytest.mark.asyncio +async def test_get_sse_server_app_without_api_key(): + with patch("mcpm.router.router.RouterSseTransport") as mock_transport: + router = MCPRouter(router_config=RouterConfig(auth_enabled=False, api_key="custom-secret")) + await router.get_sse_server_app() + mock_transport.assert_called_once() + call_kwargs = mock_transport.call_args[1] + assert call_kwargs.get("api_key") is None