From f447d3b31c85fd5c16194860c12a7f345c71fb6b Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 19:37:41 +0000 Subject: [PATCH 01/14] Add support for custom API key and router configuration in MCPRouter --- README.md | 24 ++++++++++ examples/custom_api_key_example.py | 28 ++++++++++++ examples/custom_router_config_example.py | 46 +++++++++++++++++++ src/mcpm/router/router.py | 56 ++++++++++++++++++++++-- src/mcpm/router/transport.py | 8 +++- 5 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 examples/custom_api_key_example.py create mode 100644 examples/custom_router_config_example.py diff --git a/README.md b/README.md index 3a4ecd51..47f64355 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,30 @@ mcpm router share # Share the router to public mcpm router unshare # Unshare the router ``` +#### Programmatic Usage of MCPRouter + +You can also use the `MCPRouter` class programmatically in your Python applications. This is especially useful if you want to integrate MCPM into your own applications or scripts without relying on the global configuration. + +```python +from mcpm.router.router import MCPRouter + +# Initialize with a custom API key +router = MCPRouter(api_key="your-custom-api-key") + +# Initialize with custom router configuration +router_config = { + "host": "localhost", + "port": 8080, + "share_address": "custom.share.address:8080" +} +router = MCPRouter(api_key="your-custom-api-key", router_config=router_config) + +# Optionally, create a global config from the router's configuration +router.create_global_config() +``` + +See the [examples directory](examples/) for more detailed examples of programmatic usage. + ### 🛠️ Utilities (`util`) ```bash diff --git a/examples/custom_api_key_example.py b/examples/custom_api_key_example.py new file mode 100644 index 00000000..0af82c0d --- /dev/null +++ b/examples/custom_api_key_example.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +""" +Example script demonstrating how to use MCPRouter with a custom API key. +""" + +import asyncio +from mcpm.router.router import MCPRouter + +async def main(): + # Initialize the router with a custom API key + router = MCPRouter(api_key="your-custom-api-key") + + # Optionally, create a global config from the router's configuration + # This will save the API key to the global config file + # router.create_global_config() + + # Initialize the router and start the server + app = await router.get_sse_server_app(allow_origins=["*"]) + + # Print a message to indicate that the router is ready + print("Router initialized with custom API key") + print("You can now use this router without setting up a global config") + + # In a real application, you would start the server here: + # await router.start_sse_server(host="localhost", port=8080, allow_origins=["*"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/custom_router_config_example.py b/examples/custom_router_config_example.py new file mode 100644 index 00000000..7830bf47 --- /dev/null +++ b/examples/custom_router_config_example.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +""" +Example script demonstrating how to use MCPRouter with custom configuration. +""" + +import asyncio +from mcpm.router.router import MCPRouter + +async def main(): + # Define custom router configuration + router_config = { + "host": "localhost", + "port": 8080, + "share_address": "custom.share.address:8080" + } + + # Initialize the router with a custom API key and router configuration + router = MCPRouter( + api_key="your-custom-api-key", + router_config=router_config, + # You can also specify other parameters: + # reload_server=True, # Reload the server when the config changes + # profile_path="/custom/path/to/profile.json", # Custom profile path + # strict=True, # Use strict mode for duplicated tool names + ) + + # Optionally, create a global config from the router's configuration + # This will save both the API key and router configuration to the global config file + # router.create_global_config() + + # Initialize the router and start the server + app = await router.get_sse_server_app(allow_origins=["*"]) + + # Print a message to indicate that the router is ready + print("Router initialized with custom configuration") + print("You can now use this router without setting up a global config") + + # In a real application, you would start the server here: + # await router.start_sse_server( + # host=router_config["host"], + # port=router_config["port"], + # allow_origins=["*"] + # ) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 8298ef45..02d47497 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -23,7 +23,10 @@ from mcpm.monitor.event import trace_event from mcpm.profile.profile_config import ProfileConfigManager from mcpm.schemas.server_config import ServerConfig -from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR +from mcpm.utils.config import ( + PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR, + ConfigManager, DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHARE_ADDRESS +) from .client_connection import ServerConnection from .transport import RouterSseTransport @@ -36,9 +39,33 @@ class MCPRouter: """ A router that aggregates multiple MCP servers (SSE/STDIO) and exposes them as a single SSE server. + + Example: + ```python + # Initialize with a custom API key + router = MCPRouter(api_key="your-api-key") + + # Initialize with custom router configuration + router_config = { + "host": "localhost", + "port": 8080, + "share_address": "custom.share.address:8080" + } + router = MCPRouter(api_key="your-api-key", router_config=router_config) + + # Create a global config from the router's configuration + router.create_global_config() + ``` """ - def __init__(self, reload_server: bool = False, profile_path: str | None = None, strict: bool = False) -> None: + def __init__( + self, + reload_server: bool = False, + profile_path: str | None = None, + strict: bool = False, + api_key: str | None = None, + router_config: dict | None = None + ) -> None: """ Initialize the router. @@ -46,6 +73,8 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None, :param profile_path: Path to the profile file :param strict: Whether to use strict mode for duplicated tool name. If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix + :param api_key: Optional API key to use for authentication + :param router_config: Optional router configuration to use instead of the global config """ self.server_sessions: t.Dict[str, ServerConnection] = {} self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict) @@ -60,6 +89,26 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None, if reload_server: self.watcher = ConfigWatcher(self.profile_manager.profile_path) self.strict: bool = strict + self.api_key = api_key + self.router_config = router_config + + def create_global_config(self) -> None: + """ + Create a global configuration from the router's configuration. + This is useful if you want to initialize the router with a config + but also want that config to be available globally. + """ + if self.api_key is not None: + config_manager = ConfigManager() + # Save the API key to the global config + config_manager.save_share_config(api_key=self.api_key) + + # If router_config is provided, save it to the global config + if self.router_config is not None: + host = self.router_config.get("host", DEFAULT_HOST) + port = self.router_config.get("port", DEFAULT_PORT) + share_address = self.router_config.get("share_address", DEFAULT_SHARE_ADDRESS) + config_manager.save_router_config(host, port, share_address) def get_unique_servers(self) -> list[ServerConfig]: profiles = self.profile_manager.list_profiles() @@ -496,7 +545,8 @@ async def get_sse_server_app( """ await self.initialize_router() - sse = RouterSseTransport("/messages/") + # Pass the API key to the RouterSseTransport + sse = RouterSseTransport("/messages/", api_key=self.api_key) async def handle_sse(request: Request) -> None: async with sse.connect_sse( diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 730e7b24..b321e352 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -66,8 +66,9 @@ def get_key_from_scope(scope: Scope, key_name: str) -> str | None: class RouterSseTransport(SseServerTransport): """A SSE server transport that is used by the router to handle client connections.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, api_key: str | None = None, **kwargs): self._session_id_to_identifier: dict[UUID, ClientIdentifier] = {} + self.api_key = api_key super().__init__(*args, **kwargs) @asynccontextmanager @@ -238,6 +239,11 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send): self._session_id_to_identifier.pop(session_id, None) def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: + # If we have a directly provided API key and it matches the request's API key, return True + if self.api_key is not None and api_key == self.api_key: + return True + + # Otherwise, fall back to the original validation logic try: config_manager = ConfigManager() host = get_key_from_scope(scope, key_name="host") or "" From f7bebc4f1159c5f6625e961a89e3d827f023eaf8 Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 19:39:25 +0000 Subject: [PATCH 02/14] Remove example scripts and update README --- README.md | 2 -- examples/custom_api_key_example.py | 28 --------------- examples/custom_router_config_example.py | 46 ------------------------ 3 files changed, 76 deletions(-) delete mode 100644 examples/custom_api_key_example.py delete mode 100644 examples/custom_router_config_example.py diff --git a/README.md b/README.md index 47f64355..e9405a49 100644 --- a/README.md +++ b/README.md @@ -184,8 +184,6 @@ router = MCPRouter(api_key="your-custom-api-key", router_config=router_config) router.create_global_config() ``` -See the [examples directory](examples/) for more detailed examples of programmatic usage. - ### 🛠️ Utilities (`util`) ```bash diff --git a/examples/custom_api_key_example.py b/examples/custom_api_key_example.py deleted file mode 100644 index 0af82c0d..00000000 --- a/examples/custom_api_key_example.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python -""" -Example script demonstrating how to use MCPRouter with a custom API key. -""" - -import asyncio -from mcpm.router.router import MCPRouter - -async def main(): - # Initialize the router with a custom API key - router = MCPRouter(api_key="your-custom-api-key") - - # Optionally, create a global config from the router's configuration - # This will save the API key to the global config file - # router.create_global_config() - - # Initialize the router and start the server - app = await router.get_sse_server_app(allow_origins=["*"]) - - # Print a message to indicate that the router is ready - print("Router initialized with custom API key") - print("You can now use this router without setting up a global config") - - # In a real application, you would start the server here: - # await router.start_sse_server(host="localhost", port=8080, allow_origins=["*"]) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/examples/custom_router_config_example.py b/examples/custom_router_config_example.py deleted file mode 100644 index 7830bf47..00000000 --- a/examples/custom_router_config_example.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python -""" -Example script demonstrating how to use MCPRouter with custom configuration. -""" - -import asyncio -from mcpm.router.router import MCPRouter - -async def main(): - # Define custom router configuration - router_config = { - "host": "localhost", - "port": 8080, - "share_address": "custom.share.address:8080" - } - - # Initialize the router with a custom API key and router configuration - router = MCPRouter( - api_key="your-custom-api-key", - router_config=router_config, - # You can also specify other parameters: - # reload_server=True, # Reload the server when the config changes - # profile_path="/custom/path/to/profile.json", # Custom profile path - # strict=True, # Use strict mode for duplicated tool names - ) - - # Optionally, create a global config from the router's configuration - # This will save both the API key and router configuration to the global config file - # router.create_global_config() - - # Initialize the router and start the server - app = await router.get_sse_server_app(allow_origins=["*"]) - - # Print a message to indicate that the router is ready - print("Router initialized with custom configuration") - print("You can now use this router without setting up a global config") - - # In a real application, you would start the server here: - # await router.start_sse_server( - # host=router_config["host"], - # port=router_config["port"], - # allow_origins=["*"] - # ) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file From a719bef3bce5ca344271be9e9ed4c23a17f6eb0a Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 22:20:59 +0000 Subject: [PATCH 03/14] Add support for disabling API key validation when api_key is set to None --- README.md | 3 ++ src/mcpm/router/router.py | 5 ++- src/mcpm/router/transport.py | 5 +++ tests/router/__init__.py | 0 tests/router/test_api_key_disabled.py | 47 +++++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/router/__init__.py create mode 100644 tests/router/test_api_key_disabled.py diff --git a/README.md b/README.md index e9405a49..4ab10baa 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,9 @@ router_config = { } router = MCPRouter(api_key="your-custom-api-key", router_config=router_config) +# Disable API key validation by setting api_key to None +router = MCPRouter(api_key=None) + # Optionally, create a global config from the router's configuration router.create_global_config() ``` diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 02d47497..d3a0fb75 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -53,6 +53,9 @@ class MCPRouter: } router = MCPRouter(api_key="your-api-key", router_config=router_config) + # Disable API key validation by setting api_key to None + router = MCPRouter(api_key=None) + # Create a global config from the router's configuration router.create_global_config() ``` @@ -73,7 +76,7 @@ def __init__( :param profile_path: Path to the profile file :param strict: Whether to use strict mode for duplicated tool name. If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix - :param api_key: Optional API key to use for authentication + :param api_key: Optional API key to use for authentication. Set to None to disable API key validation. :param router_config: Optional router configuration to use instead of the global config """ self.server_sessions: t.Dict[str, ServerConnection] = {} diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index b321e352..502e0ee7 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -239,6 +239,11 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send): self._session_id_to_identifier.pop(session_id, None) def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: + # If api_key is explicitly set to None, disable API key validation + if self.api_key is None: + logger.debug("API key validation disabled") + return True + # If we have a directly provided API key and it matches the request's API key, return True if self.api_key is not None and api_key == self.api_key: return True diff --git a/tests/router/__init__.py b/tests/router/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/router/test_api_key_disabled.py b/tests/router/test_api_key_disabled.py new file mode 100644 index 00000000..f0fbc7e1 --- /dev/null +++ b/tests/router/test_api_key_disabled.py @@ -0,0 +1,47 @@ +import unittest +from unittest.mock import MagicMock, patch + +from mcpm.router.transport import RouterSseTransport + + +class TestApiKeyDisabled(unittest.TestCase): + """Test that API key validation is disabled when api_key is set to None.""" + + def test_api_key_disabled(self): + """Test that API key validation is disabled when api_key is set to None.""" + # Create a transport with api_key=None + transport = RouterSseTransport("/messages/", api_key=None) + + # Mock the scope + scope = MagicMock() + + # Test that _validate_api_key returns True regardless of the api_key parameter + self.assertTrue(transport._validate_api_key(scope, api_key=None)) + self.assertTrue(transport._validate_api_key(scope, api_key="some-key")) + self.assertTrue(transport._validate_api_key(scope, api_key="invalid-key")) + + def test_api_key_enabled(self): + """Test that API key validation works when api_key is set.""" + # Create a transport with a specific api_key + transport = RouterSseTransport("/messages/", api_key="test-key") + + # Mock the scope + scope = MagicMock() + + # Test that _validate_api_key returns True only for the matching key + self.assertTrue(transport._validate_api_key(scope, api_key="test-key")) + self.assertFalse(transport._validate_api_key(scope, api_key="wrong-key")) + + # When using the default validation logic, we need to mock the ConfigManager + with patch("mcpm.router.transport.ConfigManager") as mock_config_manager: + # Set up the mock to make the default validation logic fail + mock_instance = mock_config_manager.return_value + mock_instance.read_share_config.return_value = {"url": "http://example.com", "api_key": "share-key"} + mock_instance.get_router_config.return_value = {"host": "localhost"} + + # Test with a key that doesn't match the transport's key + self.assertFalse(transport._validate_api_key(scope, api_key="wrong-key")) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 3c6ff3753846ec808d40307c5f94ce631b29f29a Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 22:38:04 +0000 Subject: [PATCH 04/14] Remove test files and revert README changes --- tests/router/__init__.py | 0 tests/router/test_api_key_disabled.py | 47 --------------------------- 2 files changed, 47 deletions(-) delete mode 100644 tests/router/__init__.py delete mode 100644 tests/router/test_api_key_disabled.py diff --git a/tests/router/__init__.py b/tests/router/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/router/test_api_key_disabled.py b/tests/router/test_api_key_disabled.py deleted file mode 100644 index f0fbc7e1..00000000 --- a/tests/router/test_api_key_disabled.py +++ /dev/null @@ -1,47 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from mcpm.router.transport import RouterSseTransport - - -class TestApiKeyDisabled(unittest.TestCase): - """Test that API key validation is disabled when api_key is set to None.""" - - def test_api_key_disabled(self): - """Test that API key validation is disabled when api_key is set to None.""" - # Create a transport with api_key=None - transport = RouterSseTransport("/messages/", api_key=None) - - # Mock the scope - scope = MagicMock() - - # Test that _validate_api_key returns True regardless of the api_key parameter - self.assertTrue(transport._validate_api_key(scope, api_key=None)) - self.assertTrue(transport._validate_api_key(scope, api_key="some-key")) - self.assertTrue(transport._validate_api_key(scope, api_key="invalid-key")) - - def test_api_key_enabled(self): - """Test that API key validation works when api_key is set.""" - # Create a transport with a specific api_key - transport = RouterSseTransport("/messages/", api_key="test-key") - - # Mock the scope - scope = MagicMock() - - # Test that _validate_api_key returns True only for the matching key - self.assertTrue(transport._validate_api_key(scope, api_key="test-key")) - self.assertFalse(transport._validate_api_key(scope, api_key="wrong-key")) - - # When using the default validation logic, we need to mock the ConfigManager - with patch("mcpm.router.transport.ConfigManager") as mock_config_manager: - # Set up the mock to make the default validation logic fail - mock_instance = mock_config_manager.return_value - mock_instance.read_share_config.return_value = {"url": "http://example.com", "api_key": "share-key"} - mock_instance.get_router_config.return_value = {"host": "localhost"} - - # Test with a key that doesn't match the transport's key - self.assertFalse(transport._validate_api_key(scope, api_key="wrong-key")) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file From 96cdd8dc60beaa5eb4d20a06f3463840d00ef61b Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 22:38:40 +0000 Subject: [PATCH 05/14] Revert docstring changes in router.py --- src/mcpm/router/router.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index d3a0fb75..5f3b92a0 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -53,9 +53,6 @@ class MCPRouter: } router = MCPRouter(api_key="your-api-key", router_config=router_config) - # Disable API key validation by setting api_key to None - router = MCPRouter(api_key=None) - # Create a global config from the router's configuration router.create_global_config() ``` @@ -76,7 +73,7 @@ def __init__( :param profile_path: Path to the profile file :param strict: Whether to use strict mode for duplicated tool name. If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix - :param api_key: Optional API key to use for authentication. Set to None to disable API key validation. + :param api_key: Optional API key to use for authentication. :param router_config: Optional router configuration to use instead of the global config """ self.server_sessions: t.Dict[str, ServerConnection] = {} From 70c8a54e04a368e51311db2f36d903548fdc0531 Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 22:38:55 +0000 Subject: [PATCH 06/14] Revert README changes --- README.md | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/README.md b/README.md index 4ab10baa..3a4ecd51 100644 --- a/README.md +++ b/README.md @@ -162,31 +162,6 @@ mcpm router share # Share the router to public mcpm router unshare # Unshare the router ``` -#### Programmatic Usage of MCPRouter - -You can also use the `MCPRouter` class programmatically in your Python applications. This is especially useful if you want to integrate MCPM into your own applications or scripts without relying on the global configuration. - -```python -from mcpm.router.router import MCPRouter - -# Initialize with a custom API key -router = MCPRouter(api_key="your-custom-api-key") - -# Initialize with custom router configuration -router_config = { - "host": "localhost", - "port": 8080, - "share_address": "custom.share.address:8080" -} -router = MCPRouter(api_key="your-custom-api-key", router_config=router_config) - -# Disable API key validation by setting api_key to None -router = MCPRouter(api_key=None) - -# Optionally, create a global config from the router's configuration -router.create_global_config() -``` - ### 🛠️ Utilities (`util`) ```bash From 01898892db64daf1410b5e42af712c59849112e3 Mon Sep 17 00:00:00 2001 From: openhands Date: Tue, 22 Apr 2025 22:43:15 +0000 Subject: [PATCH 07/14] Fix linting issues with Ruff --- src/mcpm/router/router.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 5f3b92a0..546a8438 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -24,8 +24,14 @@ from mcpm.profile.profile_config import ProfileConfigManager from mcpm.schemas.server_config import ServerConfig from mcpm.utils.config import ( - PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR, - ConfigManager, DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHARE_ADDRESS + DEFAULT_HOST, + DEFAULT_PORT, + DEFAULT_SHARE_ADDRESS, + PROMPT_SPLITOR, + RESOURCE_SPLITOR, + RESOURCE_TEMPLATE_SPLITOR, + TOOL_SPLITOR, + ConfigManager, ) from .client_connection import ServerConnection @@ -39,12 +45,12 @@ class MCPRouter: """ A router that aggregates multiple MCP servers (SSE/STDIO) and exposes them as a single SSE server. - + Example: ```python # Initialize with a custom API key router = MCPRouter(api_key="your-api-key") - + # Initialize with custom router configuration router_config = { "host": "localhost", @@ -52,16 +58,16 @@ class MCPRouter: "share_address": "custom.share.address:8080" } router = MCPRouter(api_key="your-api-key", router_config=router_config) - + # Create a global config from the router's configuration router.create_global_config() ``` """ def __init__( - self, - reload_server: bool = False, - profile_path: str | None = None, + self, + reload_server: bool = False, + profile_path: str | None = None, strict: bool = False, api_key: str | None = None, router_config: dict | None = None @@ -91,7 +97,7 @@ def __init__( self.strict: bool = strict self.api_key = api_key self.router_config = router_config - + def create_global_config(self) -> None: """ Create a global configuration from the router's configuration. @@ -102,7 +108,7 @@ def create_global_config(self) -> None: config_manager = ConfigManager() # Save the API key to the global config config_manager.save_share_config(api_key=self.api_key) - + # If router_config is provided, save it to the global config if self.router_config is not None: host = self.router_config.get("host", DEFAULT_HOST) From 3b1893c2a8c8d9efe8ef920635b3d626882a1757 Mon Sep 17 00:00:00 2001 From: Chen Nie Date: Fri, 25 Apr 2025 19:51:58 +0800 Subject: [PATCH 08/14] Add tests for router and profile --- src/mcpm/router/router.py | 142 ++++++----- src/mcpm/router/transport.py | 12 +- tests/test_profile.py | 218 ++++++++++++++++ tests/test_router.py | 477 +++++++++++++++++++++++++++++++++++ 4 files changed, 780 insertions(+), 69 deletions(-) create mode 100644 tests/test_profile.py create mode 100644 tests/test_router.py diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 546a8438..4bbd7f63 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -70,7 +70,7 @@ def __init__( profile_path: str | None = None, strict: bool = False, api_key: str | None = None, - router_config: dict | None = None + router_config: dict | None = None, ) -> None: """ Initialize the router. @@ -184,77 +184,89 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # Collect server tools, prompts, and resources if response.capabilities.tools: tools = await client.session.list_tools() # type: ignore - for tool in tools.tools: - # To make sure tool name is unique across all servers - tool_name = tool.name - if tool_name in self.capabilities_to_server_id["tools"]: - if self.strict: - raise ValueError( - f"Tool {tool_name} already exists. Please use unique tool names across all servers." - ) - else: - # Auto resolve by adding server name prefix - tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}" - self.capabilities_to_server_id["tools"][tool_name] = server_id - self.tools_mapping[tool_name] = tool + # Extract ListToolsResult from ServerResult + tools_result = tools.root + if isinstance(tools_result, types.ListToolsResult): + for tool in tools_result.tools: + # To make sure tool name is unique across all servers + tool_name = tool.name + if tool_name in self.capabilities_to_server_id["tools"]: + if self.strict: + raise ValueError( + f"Tool {tool_name} already exists. Please use unique tool names across all servers." + ) + else: + # Auto resolve by adding server name prefix + tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}" + self.capabilities_to_server_id["tools"][tool_name] = server_id + self.tools_mapping[tool_name] = tool if response.capabilities.prompts: prompts = await client.session.list_prompts() # type: ignore - for prompt in prompts.prompts: - # To make sure prompt name is unique across all servers - prompt_name = prompt.name - if prompt_name in self.capabilities_to_server_id["prompts"]: - if self.strict: - raise ValueError( - f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." - ) - else: - # Auto resolve by adding server name prefix - prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}" - self.prompts_mapping[prompt_name] = prompt - self.capabilities_to_server_id["prompts"][prompt_name] = server_id + # Extract ListPromptsResult from ServerResult + prompts_result = prompts.root + if isinstance(prompts_result, types.ListPromptsResult): + for prompt in prompts_result.prompts: + # To make sure prompt name is unique across all servers + prompt_name = prompt.name + if prompt_name in self.capabilities_to_server_id["prompts"]: + if self.strict: + raise ValueError( + f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." + ) + else: + # Auto resolve by adding server name prefix + prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}" + self.prompts_mapping[prompt_name] = prompt + self.capabilities_to_server_id["prompts"][prompt_name] = server_id if response.capabilities.resources: resources = await client.session.list_resources() # type: ignore - for resource in resources.resources: - # To make sure resource URI is unique across all servers - resource_uri = resource.uri - if str(resource_uri) in self.capabilities_to_server_id["resources"]: - if self.strict: - raise ValueError( - f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." - ) - else: - # Auto resolve by adding server name prefix - host = resource_uri.host - resource_uri = AnyUrl.build( - host=f"{server_id}{RESOURCE_SPLITOR}{host}", - scheme=resource_uri.scheme, - path=resource_uri.path, - username=resource_uri.username, - password=resource_uri.password, - port=resource_uri.port, - query=resource_uri.query, - fragment=resource_uri.fragment, - ) - self.resources_mapping[str(resource_uri)] = resource - self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id + # Extract ListResourcesResult from ServerResult + resources_result = resources.root + if isinstance(resources_result, types.ListResourcesResult): + for resource in resources_result.resources: + # To make sure resource URI is unique across all servers + resource_uri = resource.uri + if str(resource_uri) in self.capabilities_to_server_id["resources"]: + if self.strict: + raise ValueError( + f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." + ) + else: + # Auto resolve by adding server name prefix + host = resource_uri.host + resource_uri = AnyUrl.build( + host=f"{server_id}{RESOURCE_SPLITOR}{host}", + scheme=resource_uri.scheme, + path=resource_uri.path, + username=resource_uri.username, + password=resource_uri.password, + port=resource_uri.port, + query=resource_uri.query, + fragment=resource_uri.fragment, + ) + self.resources_mapping[str(resource_uri)] = resource + self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id resources_templates = await client.session.list_resource_templates() # type: ignore - for resource_template in resources_templates.resourceTemplates: - # To make sure resource template URI is unique across all servers - resource_template_uri_template = resource_template.uriTemplate - if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: - if self.strict: - raise ValueError( - f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." - ) - else: - # Auto resolve by adding server name prefix - resource_template_uri_template = ( - f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}" - ) - self.resources_templates_mapping[resource_template_uri_template] = resource_template - self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id + # Extract ListResourceTemplatesResult from ServerResult + templates_result = resources_templates.root + if isinstance(templates_result, types.ListResourceTemplatesResult): + for resource_template in templates_result.resourceTemplates: + # To make sure resource template URI is unique across all servers + resource_template_uri_template = resource_template.uriTemplate + if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: + if self.strict: + raise ValueError( + f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." + ) + else: + # Auto resolve by adding server name prefix + resource_template_uri_template = ( + f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}" + ) + self.resources_templates_mapping[resource_template_uri_template] = resource_template + self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id async def remove_server(self, server_id: str) -> None: """ diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 502e0ee7..180c340b 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -243,11 +243,15 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: if self.api_key is None: logger.debug("API key validation disabled") return True - - # If we have a directly provided API key and it matches the request's API key, return True - if self.api_key is not None and api_key == self.api_key: + + # If we have a directly provided API key, verify it matches + if self.api_key is not None: + # If API key doesn't match, return False + if api_key != self.api_key: + logger.warning("Unauthorized API key") + return False return True - + # Otherwise, fall back to the original validation logic try: config_manager = ConfigManager() diff --git a/tests/test_profile.py b/tests/test_profile.py new file mode 100644 index 00000000..08a8dfb1 --- /dev/null +++ b/tests/test_profile.py @@ -0,0 +1,218 @@ +""" +Tests for the profile module +""" + +import json +import os +import tempfile +from unittest.mock import patch + +import pytest + +from mcpm.profile.profile_config import ProfileConfigManager +from mcpm.schemas.server_config import SSEServerConfig, STDIOServerConfig + + +@pytest.fixture +def temp_profile_file(): + """Create a temporary profile config file for testing""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f: + # Create a basic profile config + config = { + "test_profile": [{"name": "test-server", "type": "sse", "url": "http://localhost:8080/sse"}], + "empty_profile": [], + } + f.write(json.dumps(config).encode("utf-8")) + temp_path = f.name + + yield temp_path + # Clean up + os.unlink(temp_path) + + +@pytest.fixture +def profile_manager(temp_profile_file): + """Create a ProfileConfigManager with a temp config for testing""" + return ProfileConfigManager(profile_path=temp_profile_file) + + +def test_profile_manager_init_default_path(): + """Test that the profile manager initializes with default path""" + with patch("mcpm.profile.profile_config.os.path.exists", return_value=False): + manager = ProfileConfigManager() + assert manager.profile_path == os.path.expanduser("~/.config/mcpm/profiles.json") + assert manager._profiles == {} + + +def test_profile_manager_init_custom_path(temp_profile_file): + """Test that the profile manager initializes with a custom path""" + manager = ProfileConfigManager(profile_path=temp_profile_file) + assert manager.profile_path == temp_profile_file + assert "test_profile" in manager._profiles + assert "empty_profile" in manager._profiles + + +def test_load_profiles_not_exists(): + """Test loading profiles when file doesn't exist""" + with patch("mcpm.profile.profile_config.os.path.exists", return_value=False): + manager = ProfileConfigManager() + profiles = manager._load_profiles() + assert profiles == {} + + +def test_load_profiles(profile_manager): + """Test loading profiles from file""" + profiles = profile_manager._load_profiles() + assert "test_profile" in profiles + assert "empty_profile" in profiles + assert len(profiles["test_profile"]) == 1 + assert len(profiles["empty_profile"]) == 0 + + +def test_new_profile(profile_manager): + """Test creating a new profile""" + # Create new profile + result = profile_manager.new_profile("new_profile") + assert result is True + assert "new_profile" in profile_manager._profiles + assert profile_manager._profiles["new_profile"] == [] + + # Test creating existing profile + result = profile_manager.new_profile("test_profile") + assert result is False + + +def test_get_profile(profile_manager): + """Test getting a profile""" + # Get existing profile + profile = profile_manager.get_profile("test_profile") + assert profile is not None + assert len(profile) == 1 + assert profile[0].name == "test-server" + + # Get non-existent profile + profile = profile_manager.get_profile("non_existent") + assert profile is None + + +def test_get_profile_server(profile_manager): + """Test getting a server from a profile""" + # Get existing server + server = profile_manager.get_profile_server("test_profile", "test-server") + assert server is not None + assert server.name == "test-server" + + # Get non-existent server + server = profile_manager.get_profile_server("test_profile", "non-existent") + assert server is None + + # Get server from non-existent profile + server = profile_manager.get_profile_server("non_existent", "test-server") + assert server is None + + +def test_set_profile_new_server(profile_manager): + """Test setting a new server in a profile""" + new_server = SSEServerConfig(name="new-server", url="http://localhost:8081/sse") + result = profile_manager.set_profile("test_profile", new_server) + assert result is True + + # Verify server was added + servers = profile_manager.get_profile("test_profile") + assert len(servers) == 2 + server_names = [s.name for s in servers] + assert "new-server" in server_names + + +def test_set_profile_update_server(profile_manager): + """Test updating an existing server in a profile""" + updated_server = SSEServerConfig(name="test-server", url="http://localhost:8082/sse") + result = profile_manager.set_profile("test_profile", updated_server) + assert result is True + + # Verify server was updated + server = profile_manager.get_profile_server("test_profile", "test-server") + assert server is not None + assert server.url == "http://localhost:8082/sse" + + +def test_set_profile_new_profile(profile_manager): + """Test setting a server in a new profile""" + new_server = STDIOServerConfig(name="stdio-server", command="test-command", args=["--arg1", "--arg2"]) + result = profile_manager.set_profile("new_profile", new_server) + assert result is True + + # Verify profile and server were created + profile = profile_manager.get_profile("new_profile") + assert profile is not None + assert len(profile) == 1 + assert profile[0].name == "stdio-server" + + +def test_delete_profile(profile_manager): + """Test deleting a profile""" + # Delete existing profile + result = profile_manager.delete_profile("test_profile") + assert result is True + assert "test_profile" not in profile_manager._profiles + + # Delete non-existent profile + result = profile_manager.delete_profile("non_existent") + assert result is False + + +def test_list_profiles(profile_manager): + """Test listing all profiles""" + profiles = profile_manager.list_profiles() + assert "test_profile" in profiles + assert "empty_profile" in profiles + assert len(profiles["test_profile"]) == 1 + assert len(profiles["empty_profile"]) == 0 + + +def test_rename_profile(profile_manager): + """Test renaming a profile""" + # Rename existing profile + result = profile_manager.rename_profile("test_profile", "renamed_profile") + assert result is True + assert "test_profile" not in profile_manager._profiles + assert "renamed_profile" in profile_manager._profiles + + # Rename to existing profile name + result = profile_manager.rename_profile("renamed_profile", "empty_profile") + assert result is False + + # Rename non-existent profile + result = profile_manager.rename_profile("non_existent", "new_name") + assert result is False + + +def test_remove_server(profile_manager): + """Test removing a server from a profile""" + # Remove existing server + result = profile_manager.remove_server("test_profile", "test-server") + assert result is True + + # Verify server was removed + profile = profile_manager.get_profile("test_profile") + assert len(profile) == 0 + + # Remove non-existent server + result = profile_manager.remove_server("test_profile", "non-existent") + assert result is False + + # Remove from non-existent profile + result = profile_manager.remove_server("non_existent", "test-server") + assert result is False + + +def test_reload(profile_manager): + """Test reloading profiles from file""" + # Modify profiles + profile_manager._profiles = {} + assert len(profile_manager._profiles) == 0 + + # Reload + profile_manager.reload() + assert "test_profile" in profile_manager._profiles + assert "empty_profile" in profile_manager._profiles diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 00000000..f8b87a20 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,477 @@ +""" +Tests for the router module +""" + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from mcp import InitializeResult +from mcp.types import ListToolsResult, ServerCapabilities, ServerResult, Tool, ToolsCapability + +from mcpm.router.client_connection import ServerConnection +from mcpm.router.router import MCPRouter +from mcpm.schemas.server_config import SSEServerConfig +from mcpm.utils.config import TOOL_SPLITOR + + +@pytest.fixture +def mock_server_connection(): + """Create a mock server connection for testing""" + mock_conn = MagicMock(spec=ServerConnection) + mock_conn.healthy.return_value = True + mock_conn.request_for_shutdown = AsyncMock() + + # Create valid ServerCapabilities with ToolsCapability + tools_capability = ToolsCapability(listChanged=False) + capabilities = ServerCapabilities( + prompts=None, resources=None, tools=tools_capability, logging=None, experimental={} + ) + + # Mock session initialized response + mock_conn.session_initialized_response = InitializeResult( + protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "test-server", "version": "1.0.0"} + ) + + # Mock session + mock_session = AsyncMock() + # Create a valid tool with proper inputSchema structure + mock_tool = Tool(name="test-tool", description="A test tool", inputSchema={"type": "object", "properties": {}}) + # Create a ListToolsResult to be the root of ServerResult + tools_result = ListToolsResult(tools=[mock_tool]) + # Create a ServerResult with ListToolsResult as its root + mock_list_tools_result = ServerResult(root=tools_result) + mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result) + mock_conn.session = mock_session + + return mock_conn + + +@pytest.mark.asyncio +async def test_router_init(): + """Test initializing the router""" + # Test with default values + router = MCPRouter() + assert router.profile_manager is not None + assert router.watcher is None + assert router.strict is False + assert router.api_key is None + assert router.router_config is None + + # Test with custom values + router_config = {"host": "custom-host", "port": 9000} + router = MCPRouter( + reload_server=True, + strict=True, + api_key="test-api-key", + router_config=router_config, + ) + + assert router.watcher is not None + assert router.strict is True + assert router.api_key == "test-api-key" + assert router.router_config == router_config + + +def test_create_global_config(): + """Test creating a global config from router config""" + router_config = {"host": "custom-host", "port": 9000, "share_address": "custom-share-address"} + + with patch("mcpm.router.router.ConfigManager") as mock_config_manager: + mock_instance = Mock() + mock_config_manager.return_value = mock_instance + + # Test without API key + router = MCPRouter(router_config=router_config) + router.create_global_config() + mock_instance.save_share_config.assert_not_called() + mock_instance.save_router_config.assert_not_called() + + # Test with API key + router = MCPRouter(api_key="test-api-key", router_config=router_config) + router.create_global_config() + mock_instance.save_share_config.assert_called_once_with(api_key="test-api-key") + mock_instance.save_router_config.assert_called_once_with("custom-host", 9000, "custom-share-address") + + +@pytest.mark.asyncio +async def test_add_server(mock_server_connection): + """Test adding a server to the router""" + router = MCPRouter() + + # Mock get_active_servers to return all server IDs + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + # Patch the _patch_handler_func method to use our mock + with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + server_config = SSEServerConfig(name="test-server", url="http://localhost:8080/sse") + + with patch("mcpm.router.router.ServerConnection", return_value=mock_server_connection): + await router.add_server("test-server", server_config) + + # Verify server was added + assert "test-server" in router.server_sessions + assert router.server_sessions["test-server"] == mock_server_connection + + # Verify capabilities were stored + assert "test-server" in router.capabilities_mapping + + # Verify tool was stored + assert "test-tool" in router.tools_mapping + assert router.capabilities_to_server_id["tools"]["test-tool"] == "test-server" + + # Test adding duplicate server + with pytest.raises(ValueError): + await router.add_server("test-server", server_config) + + +@pytest.mark.asyncio +async def test_add_server_unhealthy(): + """Test adding an unhealthy server""" + router = MCPRouter() + server_config = SSEServerConfig(name="unhealthy-server", url="http://localhost:8080/sse") + + mock_conn = MagicMock(spec=ServerConnection) + mock_conn.healthy.return_value = False + + with patch("mcpm.router.router.ServerConnection", return_value=mock_conn): + with pytest.raises(ValueError, match="Failed to connect to server unhealthy-server"): + await router.add_server("unhealthy-server", server_config) + + +@pytest.mark.asyncio +async def test_add_server_duplicate_tool_strict(): + """Test adding a server with duplicate tool name in strict mode""" + router = MCPRouter(strict=True) + + # Mock get_active_servers to return all server IDs + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + # Patch the _patch_handler_func method to use our mock + with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + server_config = SSEServerConfig(name="test-server", url="http://localhost:8080/sse") + + # Add first server with a tool + mock_conn1 = MagicMock(spec=ServerConnection) + mock_conn1.healthy.return_value = True + mock_conn1.request_for_shutdown = AsyncMock() + + # Create valid ServerCapabilities with ToolsCapability + tools_capability = ToolsCapability(listChanged=False) + capabilities = ServerCapabilities( + prompts=None, resources=None, tools=tools_capability, logging=None, experimental={} + ) + + mock_conn1.session_initialized_response = InitializeResult( + protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "test-server", "version": "1.0.0"} + ) + + mock_session1 = AsyncMock() + mock_tool = Tool( + name="duplicate-tool", description="A test tool", inputSchema={"type": "object", "properties": {}} + ) + # Create a ListToolsResult to be the root of ServerResult + tools_result = ListToolsResult(tools=[mock_tool]) + # Create a ServerResult with ListToolsResult as its root + mock_list_tools_result = ServerResult(root=tools_result) + mock_session1.list_tools = AsyncMock(return_value=mock_list_tools_result) + mock_conn1.session = mock_session1 + + # Add second server with same tool name + mock_conn2 = MagicMock(spec=ServerConnection) + mock_conn2.healthy.return_value = True + mock_conn2.request_for_shutdown = AsyncMock() + + mock_conn2.session_initialized_response = InitializeResult( + protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "second-server", "version": "1.0.0"} + ) + + mock_session2 = AsyncMock() + mock_session2.list_tools = AsyncMock(return_value=mock_list_tools_result) + mock_conn2.session = mock_session2 + + with patch("mcpm.router.router.ServerConnection", side_effect=[mock_conn1, mock_conn2]): + # Add first server should succeed + await router.add_server("test-server", server_config) + + # Add second server with duplicate tool should fail in strict mode + with pytest.raises(ValueError, match="Tool duplicate-tool already exists"): + await router.add_server("second-server", server_config) + + +@pytest.mark.asyncio +async def test_add_server_duplicate_tool_non_strict(): + """Test adding a server with duplicate tool name in non-strict mode""" + router = MCPRouter(strict=False) + + # Mock get_active_servers to return all server IDs + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + # Patch the _patch_handler_func method to use our mock + with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + server_config = SSEServerConfig(name="test-server", url="http://localhost:8080/sse") + second_server_config = SSEServerConfig(name="second-server", url="http://localhost:8081/sse") + + # Add first server with a tool + mock_conn1 = MagicMock(spec=ServerConnection) + mock_conn1.healthy.return_value = True + mock_conn1.request_for_shutdown = AsyncMock() + + # Create valid ServerCapabilities with ToolsCapability + tools_capability = ToolsCapability(listChanged=False) + capabilities = ServerCapabilities( + prompts=None, resources=None, tools=tools_capability, logging=None, experimental={} + ) + + mock_conn1.session_initialized_response = InitializeResult( + protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "test-server", "version": "1.0.0"} + ) + + mock_session1 = AsyncMock() + mock_tool = Tool( + name="duplicate-tool", description="A test tool", inputSchema={"type": "object", "properties": {}} + ) + # Create a ListToolsResult to be the root of ServerResult + tools_result = ListToolsResult(tools=[mock_tool]) + # Create a ServerResult with ListToolsResult as its root + mock_list_tools_result = ServerResult(root=tools_result) + mock_session1.list_tools = AsyncMock(return_value=mock_list_tools_result) + mock_conn1.session = mock_session1 + + # Add second server with same tool name + mock_conn2 = MagicMock(spec=ServerConnection) + mock_conn2.healthy.return_value = True + mock_conn2.request_for_shutdown = AsyncMock() + + mock_conn2.session_initialized_response = InitializeResult( + protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "second-server", "version": "1.0.0"} + ) + + mock_session2 = AsyncMock() + mock_session2.list_tools = AsyncMock(return_value=mock_list_tools_result) + mock_conn2.session = mock_session2 + + with patch("mcpm.router.router.ServerConnection", side_effect=[mock_conn1, mock_conn2]): + # Add first server + await router.add_server("test-server", server_config) + assert "duplicate-tool" in router.tools_mapping + assert router.capabilities_to_server_id["tools"]["duplicate-tool"] == "test-server" + + # Add second server with duplicate tool - should prefix the tool name + await router.add_server("second-server", second_server_config) + prefixed_tool_name = f"second-server{TOOL_SPLITOR}duplicate-tool" + assert prefixed_tool_name in router.capabilities_to_server_id["tools"] + assert router.capabilities_to_server_id["tools"][prefixed_tool_name] == "second-server" + + +@pytest.mark.asyncio +async def test_remove_server(): + """Test removing a server from the router""" + router = MCPRouter() + + # Setup mock server session with an awaitable request_for_shutdown + mock_session = AsyncMock() + mock_session.close = AsyncMock() + + mock_server = MagicMock(spec=ServerConnection) + mock_server.session = mock_session + mock_server.request_for_shutdown = AsyncMock() + + # Mock server and capabilities + router.server_sessions = {"test-server": mock_server} + router.capabilities_mapping = {"test-server": {"tools": True}} + router.capabilities_to_server_id = {"tools": {"test-tool": "test-server"}} + router.tools_mapping = {"test-tool": MagicMock()} + + # Remove server + await router.remove_server("test-server") + + # Verify server was removed + assert "test-server" not in router.server_sessions + assert "test-server" not in router.capabilities_mapping + assert "test-tool" not in router.capabilities_to_server_id["tools"] + assert "test-tool" not in router.tools_mapping + + # Verify request_for_shutdown was called + mock_server.request_for_shutdown.assert_called_once() + + # Test removing non-existent server + with pytest.raises(ValueError, match="Server with ID non-existent does not exist"): + await router.remove_server("non-existent") + + +@pytest.mark.asyncio +async def test_update_servers(mock_server_connection): + """Test updating servers based on configuration""" + router = MCPRouter() + + # Mock get_active_servers to return all server IDs + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + # Patch the _patch_handler_func method to use our mock + with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + # Setup initial servers with awaitable request_for_shutdown + mock_old_server = MagicMock(spec=ServerConnection) + mock_old_server.session = AsyncMock() + mock_old_server.request_for_shutdown = AsyncMock() + + router.server_sessions = {"old-server": mock_old_server} + # Initialize capabilities_mapping for the old server + router.capabilities_mapping = {"old-server": {"tools": True}} + + # Configure new servers + server_configs = [SSEServerConfig(name="test-server", url="http://localhost:8080/sse")] + + with patch("mcpm.router.router.ServerConnection", return_value=mock_server_connection): + await router.update_servers(server_configs) + + # Verify old server was removed + assert "old-server" not in router.server_sessions + mock_old_server.request_for_shutdown.assert_called_once() + + # Verify new server was added + assert "test-server" in router.server_sessions + + # Test with empty configs - should not change anything + router.server_sessions = {"test-server": mock_server_connection} + await router.update_servers([]) + assert "test-server" in router.server_sessions + + +@pytest.mark.asyncio +async def test_update_servers_error_handling(): + """Test error handling during server updates""" + router = MCPRouter() + + # Setup initial servers with awaitable request_for_shutdown + mock_old_server = MagicMock(spec=ServerConnection) + mock_old_server.session = AsyncMock() + mock_old_server.request_for_shutdown = AsyncMock() + + router.server_sessions = {"old-server": mock_old_server} + # Initialize capabilities_mapping for the old server + router.capabilities_mapping = {"old-server": {"tools": True}} + + # Configure new servers + server_configs = [SSEServerConfig(name="test-server", url="http://localhost:8080/sse")] + + # Mock add_server to raise exception + with patch.object(router, "add_server", side_effect=Exception("Test error")): + # Should not raise exception + await router.update_servers(server_configs) + + # Old server should still be removed + assert "old-server" not in router.server_sessions + mock_old_server.request_for_shutdown.assert_called_once() + + # New server should not be added + assert "test-server" not in router.server_sessions + + +@pytest.mark.asyncio +async def test_router_sse_transport_no_api_key(): + """Test RouterSseTransport with no API key (authentication disabled)""" + + from mcpm.router.transport import RouterSseTransport + + # Create a RouterSseTransport with no API key + transport = RouterSseTransport("/messages/", api_key=None) + + # Create a mock scope + mock_scope = {"type": "http"} + + # Test _validate_api_key method directly + assert transport._validate_api_key(mock_scope, api_key=None) + assert transport._validate_api_key(mock_scope, api_key="any-key") + + # Test with various API key values - all should be allowed + assert transport._validate_api_key(mock_scope, api_key="test-key") + assert transport._validate_api_key(mock_scope, api_key="invalid-key") + assert transport._validate_api_key(mock_scope, api_key="") + + +@pytest.mark.asyncio +async def test_router_sse_transport_with_api_key(): + """Test RouterSseTransport with API key (authentication enabled)""" + + from mcpm.router.transport import RouterSseTransport + + # Create a RouterSseTransport with an API key + transport = RouterSseTransport("/messages/", api_key="correct-api-key") + + # Create a mock scope + mock_scope = {"type": "http"} + + # Test _validate_api_key method directly + # With the correct API key + assert transport._validate_api_key(mock_scope, api_key="correct-api-key") + + # With an incorrect API key + assert not transport._validate_api_key(mock_scope, api_key="wrong-api-key") + + # With no API key + assert not transport._validate_api_key(mock_scope, api_key=None) + + # Test with empty string + assert not transport._validate_api_key(mock_scope, api_key="") + + +@pytest.mark.asyncio +async def test_get_sse_server_app_with_api_key(): + """Test that the API key is passed to RouterSseTransport when creating the server app""" + router = MCPRouter(api_key="test-api-key") + + # Patch the RouterSseTransport constructor and get_active_servers method + with ( + patch("mcpm.router.router.RouterSseTransport") as mock_transport, + patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler, + ): + # Set up mocks for initialization + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + # Call the method + await router.get_sse_server_app() + + # Check that RouterSseTransport was created with the correct API key + mock_transport.assert_called_once() + call_kwargs = mock_transport.call_args[1] + assert call_kwargs.get("api_key") == "test-api-key" + + +@pytest.mark.asyncio +async def test_get_sse_server_app_without_api_key(): + """Test that None is passed to RouterSseTransport when no API key is provided""" + router = MCPRouter() # No API key + + # Patch the RouterSseTransport constructor and get_active_servers method + with ( + patch("mcpm.router.router.RouterSseTransport") as mock_transport, + patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler, + ): + # Set up mocks for initialization + def mock_get_active_servers(_profile): + return list(router.server_sessions.keys()) + + mock_patch_handler.return_value.get_active_servers = mock_get_active_servers + + # Call the method + await router.get_sse_server_app() + + # Check that RouterSseTransport was created with api_key=None + mock_transport.assert_called_once() + call_kwargs = mock_transport.call_args[1] + assert call_kwargs.get("api_key") is None From a9c942a8c3746700470ce54225d90a46f2c010a4 Mon Sep 17 00:00:00 2001 From: Chen Nie Date: Sun, 27 Apr 2025 11:52:47 +0800 Subject: [PATCH 09/14] refactor router config --- .cursor/rules/pytest.mdc | 6 +++ src/mcpm/__init__.py | 4 +- src/mcpm/router/router.py | 65 +++++++++++++++----------------- src/mcpm/router/router_config.py | 17 +++++++++ src/mcpm/router/transport.py | 20 +++++----- tests/test_router.py | 38 +++++++++---------- 6 files changed, 84 insertions(+), 66 deletions(-) create mode 100644 .cursor/rules/pytest.mdc create mode 100644 src/mcpm/router/router_config.py diff --git a/.cursor/rules/pytest.mdc b/.cursor/rules/pytest.mdc new file mode 100644 index 00000000..b3bdf4c6 --- /dev/null +++ b/.cursor/rules/pytest.mdc @@ -0,0 +1,6 @@ +--- +description: +globs: *.py +alwaysApply: false +--- +always run pytest at the end of a major change \ No newline at end of file diff --git a/src/mcpm/__init__.py b/src/mcpm/__init__.py index 5d77ba78..164c1afe 100644 --- a/src/mcpm/__init__.py +++ b/src/mcpm/__init__.py @@ -5,7 +5,9 @@ # Import version from internal module # Import router module from . import router +from .router.router import MCPRouter +from .router.router_config import RouterConfig from .version import __version__ # Define what symbols are exported from this package -__all__ = ["__version__", "router"] +__all__ = ["__version__", "router", "MCPRouter", "RouterConfig"] diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 4bbd7f63..63097cf6 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -24,9 +24,6 @@ from mcpm.profile.profile_config import ProfileConfigManager from mcpm.schemas.server_config import ServerConfig from mcpm.utils.config import ( - DEFAULT_HOST, - DEFAULT_PORT, - DEFAULT_SHARE_ADDRESS, PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, @@ -35,6 +32,7 @@ ) from .client_connection import ServerConnection +from .router_config import RouterConfig from .transport import RouterSseTransport from .watcher import ConfigWatcher @@ -49,15 +47,16 @@ class MCPRouter: Example: ```python # Initialize with a custom API key - router = MCPRouter(api_key="your-api-key") + router = MCPRouter(router_config=RouterConfig(api_key="your-api-key")) # Initialize with custom router configuration - router_config = { - "host": "localhost", - "port": 8080, - "share_address": "custom.share.address:8080" - } - router = MCPRouter(api_key="your-api-key", router_config=router_config) + router_config = RouterConfig( + host="localhost", + port=8080, + share_address="custom.share.address:8080", + api_key="your-api-key" + ) + router = MCPRouter(router_config=router_config) # Create a global config from the router's configuration router.create_global_config() @@ -68,18 +67,13 @@ def __init__( self, reload_server: bool = False, profile_path: str | None = None, - strict: bool = False, - api_key: str | None = None, - router_config: dict | None = None, + router_config: RouterConfig | None = None, ) -> None: """ Initialize the router. :param reload_server: Whether to reload the server when the config changes :param profile_path: Path to the profile file - :param strict: Whether to use strict mode for duplicated tool name. - If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix - :param api_key: Optional API key to use for authentication. :param router_config: Optional router configuration to use instead of the global config """ self.server_sessions: t.Dict[str, ServerConnection] = {} @@ -94,9 +88,7 @@ def __init__( self.watcher: Optional[ConfigWatcher] = None if reload_server: self.watcher = ConfigWatcher(self.profile_manager.profile_path) - self.strict: bool = strict - self.api_key = api_key - self.router_config = router_config + self.router_config = router_config if router_config is not None else RouterConfig() def create_global_config(self) -> None: """ @@ -104,17 +96,19 @@ def create_global_config(self) -> None: This is useful if you want to initialize the router with a config but also want that config to be available globally. """ - if self.api_key is not None: - config_manager = ConfigManager() - # Save the API key to the global config - config_manager.save_share_config(api_key=self.api_key) - - # If router_config is provided, save it to the global config - if self.router_config is not None: - host = self.router_config.get("host", DEFAULT_HOST) - port = self.router_config.get("port", DEFAULT_PORT) - share_address = self.router_config.get("share_address", DEFAULT_SHARE_ADDRESS) - config_manager.save_router_config(host, port, share_address) + # Skip if router_config is None or there's no explicit api_key set + if self.router_config is None or self.router_config.api_key is None: + return + + config_manager = ConfigManager() + + # Save the API key to the global config + config_manager.save_share_config(api_key=self.router_config.api_key) + + # Save router configuration to the global config + config_manager.save_router_config( + self.router_config.host, self.router_config.port, self.router_config.share_address + ) def get_unique_servers(self) -> list[ServerConfig]: profiles = self.profile_manager.list_profiles() @@ -191,7 +185,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure tool name is unique across all servers tool_name = tool.name if tool_name in self.capabilities_to_server_id["tools"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Tool {tool_name} already exists. Please use unique tool names across all servers." ) @@ -210,7 +204,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure prompt name is unique across all servers prompt_name = prompt.name if prompt_name in self.capabilities_to_server_id["prompts"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." ) @@ -229,7 +223,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure resource URI is unique across all servers resource_uri = resource.uri if str(resource_uri) in self.capabilities_to_server_id["resources"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." ) @@ -256,7 +250,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # To make sure resource template URI is unique across all servers resource_template_uri_template = resource_template.uriTemplate if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: - if self.strict: + if self.router_config.strict: raise ValueError( f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." ) @@ -564,7 +558,8 @@ async def get_sse_server_app( await self.initialize_router() # Pass the API key to the RouterSseTransport - sse = RouterSseTransport("/messages/", api_key=self.api_key) + api_key = None if self.router_config is None else self.router_config.api_key + sse = RouterSseTransport("/messages/", api_key=api_key) async def handle_sse(request: Request) -> None: async with sse.connect_sse( diff --git a/src/mcpm/router/router_config.py b/src/mcpm/router/router_config.py new file mode 100644 index 00000000..92ae2228 --- /dev/null +++ b/src/mcpm/router/router_config.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import BaseModel + +from mcpm.utils.config import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHARE_ADDRESS + + +class RouterConfig(BaseModel): + """ + Router configuration model for MCPRouter + """ + + host: str = DEFAULT_HOST + port: int = DEFAULT_PORT + share_address: str = DEFAULT_SHARE_ADDRESS + api_key: Optional[str] = None + strict: bool = False diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 180c340b..ac3399c5 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -245,14 +245,11 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: return True # If we have a directly provided API key, verify it matches - if self.api_key is not None: - # If API key doesn't match, return False - if api_key != self.api_key: - logger.warning("Unauthorized API key") - return False + if api_key == self.api_key: return True - # Otherwise, fall back to the original validation logic + # At this point, self.api_key is not None but doesn't match the provided api_key + # Let's check if this is a share URL that needs special validation try: config_manager = ConfigManager() host = get_key_from_scope(scope, key_name="host") or "" @@ -264,10 +261,11 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: share_host_name = urlsplit(share_config["url"]).hostname if share_config["url"] and (host_name == share_host_name or host_name != router_config["host"]): share_api_key = share_config["api_key"] - if api_key != share_api_key: - logger.warning("Unauthorized API key") - return False + if api_key == share_api_key: + return True except Exception as e: logger.error(f"Failed to validate API key: {e}") - return False - return True + + # If we reach here, the API key is invalid + logger.warning("Unauthorized API key") + return False diff --git a/tests/test_router.py b/tests/test_router.py index f8b87a20..dc7fa895 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -10,6 +10,7 @@ from mcpm.router.client_connection import ServerConnection from mcpm.router.router import MCPRouter +from mcpm.router.router_config import RouterConfig from mcpm.schemas.server_config import SSEServerConfig from mcpm.utils.config import TOOL_SPLITOR @@ -53,41 +54,40 @@ async def test_router_init(): router = MCPRouter() assert router.profile_manager is not None assert router.watcher is None - assert router.strict is False - assert router.api_key is None - assert router.router_config is None + assert router.router_config is not None + assert router.router_config.strict is False # Test with custom values - router_config = {"host": "custom-host", "port": 9000} + config = RouterConfig( + host="custom-host", port=9000, share_address="custom-share-address", api_key="test-api-key", strict=True + ) router = MCPRouter( reload_server=True, - strict=True, - api_key="test-api-key", - router_config=router_config, + router_config=config, ) assert router.watcher is not None - assert router.strict is True - assert router.api_key == "test-api-key" - assert router.router_config == router_config + assert router.router_config == config + assert router.router_config.api_key == "test-api-key" + assert router.router_config.strict is True def test_create_global_config(): """Test creating a global config from router config""" - router_config = {"host": "custom-host", "port": 9000, "share_address": "custom-share-address"} + config = RouterConfig(host="custom-host", port=9000, share_address="custom-share-address", api_key="test-api-key") with patch("mcpm.router.router.ConfigManager") as mock_config_manager: mock_instance = Mock() mock_config_manager.return_value = mock_instance - # Test without API key - router = MCPRouter(router_config=router_config) + # Test without router_config + router = MCPRouter() router.create_global_config() mock_instance.save_share_config.assert_not_called() mock_instance.save_router_config.assert_not_called() - # Test with API key - router = MCPRouter(api_key="test-api-key", router_config=router_config) + # Test with router_config + router = MCPRouter(router_config=config) router.create_global_config() mock_instance.save_share_config.assert_called_once_with(api_key="test-api-key") mock_instance.save_router_config.assert_called_once_with("custom-host", 9000, "custom-share-address") @@ -144,7 +144,7 @@ async def test_add_server_unhealthy(): @pytest.mark.asyncio async def test_add_server_duplicate_tool_strict(): """Test adding a server with duplicate tool name in strict mode""" - router = MCPRouter(strict=True) + router = MCPRouter(router_config=RouterConfig(strict=True)) # Mock get_active_servers to return all server IDs def mock_get_active_servers(_profile): @@ -207,7 +207,7 @@ def mock_get_active_servers(_profile): @pytest.mark.asyncio async def test_add_server_duplicate_tool_non_strict(): """Test adding a server with duplicate tool name in non-strict mode""" - router = MCPRouter(strict=False) + router = MCPRouter(router_config=RouterConfig(strict=False)) # Mock get_active_servers to return all server IDs def mock_get_active_servers(_profile): @@ -430,7 +430,7 @@ async def test_router_sse_transport_with_api_key(): @pytest.mark.asyncio async def test_get_sse_server_app_with_api_key(): """Test that the API key is passed to RouterSseTransport when creating the server app""" - router = MCPRouter(api_key="test-api-key") + router = MCPRouter(router_config=RouterConfig(api_key="test-api-key")) # Patch the RouterSseTransport constructor and get_active_servers method with ( @@ -455,7 +455,7 @@ def mock_get_active_servers(_profile): @pytest.mark.asyncio async def test_get_sse_server_app_without_api_key(): """Test that None is passed to RouterSseTransport when no API key is provided""" - router = MCPRouter() # No API key + router = MCPRouter() # No API key or router_config # Patch the RouterSseTransport constructor and get_active_servers method with ( From 40f21021913bf5d40648a81c383ddba894b7394c Mon Sep 17 00:00:00 2001 From: Chen Nie Date: Sun, 27 Apr 2025 16:48:16 +0800 Subject: [PATCH 10/14] add --auth/--no-auth router config --- .cursor/rules/pytest.mdc | 3 +- src/mcpm/commands/router.py | 71 +++++++++++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/.cursor/rules/pytest.mdc b/.cursor/rules/pytest.mdc index b3bdf4c6..f56a7985 100644 --- a/.cursor/rules/pytest.mdc +++ b/.cursor/rules/pytest.mdc @@ -3,4 +3,5 @@ description: globs: *.py alwaysApply: false --- -always run pytest at the end of a major change \ No newline at end of file +always run pytest at the end of a major change +always run ruff lint at then end of any code changes \ No newline at end of file diff --git a/src/mcpm/commands/router.py b/src/mcpm/commands/router.py index b6a8932c..8e6a17bb 100644 --- a/src/mcpm/commands/router.py +++ b/src/mcpm/commands/router.py @@ -41,6 +41,7 @@ def is_process_running(pid): except Exception: return False + def is_port_listening(host, port) -> bool: """ Check if the specified (host, port) is being listened on. @@ -133,10 +134,18 @@ def start_router(verbose): return # get router config - config = ConfigManager().get_router_config() + config_manager = ConfigManager() + config = config_manager.get_router_config() host = config["host"] port = config["port"] + # Check if we have an API key, if not, create one + share_config = config_manager.read_share_config() + if share_config.get("api_key") is None: + api_key = secrets.token_urlsafe(32) + config_manager.save_share_config(api_key=api_key) + console.print("[bold green]Created API key for router authentication[/]") + # prepare uvicorn command uvicorn_cmd = [ sys.executable, @@ -185,9 +194,29 @@ def start_router(verbose): pid = process.pid write_pid_file(pid) + # Display router started information console.print(f"[bold green]MCPRouter started[/] at http://{host}:{port} (PID: {pid})") console.print(f"Log file: {log_file}") - console.print("Use 'mcpm router off' to stop the router.") + + # Display connection instructions + console.print("\n[bold cyan]Connection Information:[/]") + + # Get API key if available + api_key = config_manager.read_share_config().get("api_key") + + # Show URL with or without authentication based on API key availability + if api_key: + # Show authenticated URL + console.print(f"SSE Server URL: [green]http://{host}:{port}/sse?s={api_key}[/]") + console.print("\n[bold cyan]To use a specific profile with authentication:[/]") + console.print(f"[green]http://{host}:{port}/sse?s={api_key}&profile=[/]") + else: + # Show URL without authentication + console.print(f"SSE Server URL: [green]http://{host}:{port}/sse[/]") + console.print("\n[bold cyan]To use a specific profile:[/]") + console.print(f"[green]http://{host}:{port}/sse?profile=[/]") + + console.print("\n[yellow]Use 'mcpm router off' to stop the router.[/]") except Exception as e: console.print(f"[bold red]Error:[/] Failed to start MCPRouter: {e}") @@ -197,17 +226,22 @@ def start_router(verbose): @click.option("-H", "--host", type=str, help="Host to bind the SSE server to") @click.option("-p", "--port", type=int, help="Port to bind the SSE server to") @click.option("-a", "--address", type=str, help="Remote address to share the router") +@click.option( + "--auth/--no-auth", default=True, is_flag=True, help="Enable/disable API key authentication (default: enabled)" +) @click.help_option("-h", "--help") -def set_router_config(host, port, address): +def set_router_config(host, port, address, auth): """Set MCPRouter global configuration. Example: mcpm router set -H localhost -p 8888 mcpm router set --host 127.0.0.1 --port 9000 + mcpm router set --no-auth # disable authentication + mcpm router set --auth # enable authentication """ - if not host and not port and not address: + if not host and not port and not address and auth is None: console.print( - "[yellow]No changes were made. Please specify at least one option (--host, --port, or --address)[/]" + "[yellow]No changes were made. Please specify at least one option (--host, --port, --address, --auth/--no-auth)[/]" ) return @@ -220,7 +254,32 @@ def set_router_config(host, port, address): port = port or current_config["port"] share_address = address or current_config["share_address"] - # save config + # Handle authentication setting + share_config = config_manager.read_share_config() + current_api_key = share_config.get("api_key") + + if auth: + # Enable authentication + if current_api_key is None: + # Generate a new API key if authentication is enabled but no key exists + api_key = secrets.token_urlsafe(32) + config_manager.save_share_config( + share_url=share_config.get("url"), share_pid=share_config.get("pid"), api_key=api_key + ) + console.print("[bold green]API key authentication enabled.[/] Generated new API key.") + else: + console.print("[bold green]API key authentication enabled.[/] Using existing API key.") + else: + # Disable authentication by clearing the API key + if current_api_key is not None: + config_manager.save_share_config( + share_url=share_config.get("url"), share_pid=share_config.get("pid"), api_key=None + ) + console.print("[bold yellow]API key authentication disabled.[/]") + else: + console.print("[bold yellow]API key authentication was already disabled.[/]") + + # save router config if config_manager.save_router_config(host, port, share_address): console.print( f"[bold green]Router configuration updated:[/] host={host}, port={port}, share_address={share_address}" From 3d215d546138fa8c9c1c3eea77520c2caff2c19a Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Mon, 28 Apr 2025 11:31:34 +0800 Subject: [PATCH 11/14] feat: enable custom api key and auth enable / disable --- src/mcpm/commands/router.py | 55 +++----- src/mcpm/router/app.py | 7 +- src/mcpm/router/router.py | 174 ++++++++++--------------- src/mcpm/router/router_config.py | 17 ++- src/mcpm/router/transport.py | 13 +- src/mcpm/utils/config.py | 8 +- tests/test_router.py | 211 ++----------------------------- 7 files changed, 129 insertions(+), 356 deletions(-) diff --git a/src/mcpm/commands/router.py b/src/mcpm/commands/router.py index 8e6a17bb..ce5f3c3e 100644 --- a/src/mcpm/commands/router.py +++ b/src/mcpm/commands/router.py @@ -9,7 +9,6 @@ import socket import subprocess import sys -import uuid import click import psutil @@ -138,13 +137,8 @@ def start_router(verbose): config = config_manager.get_router_config() host = config["host"] port = config["port"] - - # Check if we have an API key, if not, create one - share_config = config_manager.read_share_config() - if share_config.get("api_key") is None: - api_key = secrets.token_urlsafe(32) - config_manager.save_share_config(api_key=api_key) - console.print("[bold green]Created API key for router authentication[/]") + auth_enabled = config.get("auth_enabled", False) + api_key = config.get("api_key") # prepare uvicorn command uvicorn_cmd = [ @@ -201,8 +195,7 @@ def start_router(verbose): # Display connection instructions console.print("\n[bold cyan]Connection Information:[/]") - # Get API key if available - api_key = config_manager.read_share_config().get("api_key") + api_key = api_key if auth_enabled else None # Show URL with or without authentication based on API key availability if api_key: @@ -229,8 +222,9 @@ def start_router(verbose): @click.option( "--auth/--no-auth", default=True, is_flag=True, help="Enable/disable API key authentication (default: enabled)" ) +@click.option("-s", "--secret", type=str, help="Secret key for authentication") @click.help_option("-h", "--help") -def set_router_config(host, port, address, auth): +def set_router_config(host, port, address, auth, secret: str | None = None): """Set MCPRouter global configuration. Example: @@ -253,34 +247,23 @@ def set_router_config(host, port, address, auth): host = host or current_config["host"] port = port or current_config["port"] share_address = address or current_config["share_address"] - - # Handle authentication setting - share_config = config_manager.read_share_config() - current_api_key = share_config.get("api_key") + api_key = secret if auth: # Enable authentication - if current_api_key is None: + if api_key is None: # Generate a new API key if authentication is enabled but no key exists api_key = secrets.token_urlsafe(32) - config_manager.save_share_config( - share_url=share_config.get("url"), share_pid=share_config.get("pid"), api_key=api_key - ) console.print("[bold green]API key authentication enabled.[/] Generated new API key.") else: - console.print("[bold green]API key authentication enabled.[/] Using existing API key.") + console.print("[bold green]API key authentication enabled.[/] Using provided API key.") else: # Disable authentication by clearing the API key - if current_api_key is not None: - config_manager.save_share_config( - share_url=share_config.get("url"), share_pid=share_config.get("pid"), api_key=None - ) - console.print("[bold yellow]API key authentication disabled.[/]") - else: - console.print("[bold yellow]API key authentication was already disabled.[/]") + api_key = None + console.print("[bold yellow]API key authentication disabled.[/]") # save router config - if config_manager.save_router_config(host, port, share_address): + if config_manager.save_router_config(host, port, share_address, api_key=api_key, auth_enabled=auth): console.print( f"[bold green]Router configuration updated:[/] host={host}, port={port}, share_address={share_address}" ) @@ -388,7 +371,7 @@ def router_status(): if share_config.get("pid"): if not is_process_running(share_config["pid"]): console.print("[yellow]Share link is not active, cleaning.[/]") - ConfigManager().save_share_config(share_url=None, share_pid=None, api_key=None) + ConfigManager().save_share_config(share_url=None, share_pid=None) console.print("[green]Share link cleaned[/]") else: console.print( @@ -448,17 +431,17 @@ def share(address, profile, http): tunnel = Tunnel(remote_host, remote_port, config["host"], config["port"], secrets.token_urlsafe(32), http, None) share_url = tunnel.start_tunnel() share_pid = tunnel.proc.pid if tunnel.proc else None - # generate random api key - api_key = str(uuid.uuid4()) - console.print(f"[bold green]Generated secret for share link: {api_key}[/]") + api_key = config.get("api_key") if config.get("auth_enabled") else None share_url = share_url + "/sse" # save share pid and link to config - config_manager.save_share_config(share_url, share_pid, api_key) + config_manager.save_share_config(share_url, share_pid) profile = profile or "" # print share link console.print(f"[bold green]Router is sharing at {share_url}[/]") - console.print(f"[green]Your profile can be accessed with the url {share_url}?s={api_key}&profile={profile}[/]\n") + console.print( + f"[green]Your profile can be accessed with the url {share_url}?profile={profile}{f'&s={api_key}' if api_key else ''}[/]\n" + ) console.print( "[bold yellow]Be careful about the share link, it will be exposed to the public. Make sure to share to trusted users only.[/]" ) @@ -471,14 +454,14 @@ def try_clear_share(): if share_config["url"]: try: console.print("[bold yellow]Disabling share link...[/]") - config_manager.save_share_config(share_url=None, share_pid=None, api_key=None) + config_manager.save_share_config(share_url=None, share_pid=None) console.print("[bold green]Share link disabled[/]") if share_config["pid"]: os.kill(share_config["pid"], signal.SIGTERM) except OSError as e: if e.errno == 3: # "No such process" console.print("[yellow]Share process does not exist, cleaning up share config...[/]") - config_manager.save_share_config(share_url=None, share_pid=None, api_key=None) + config_manager.save_share_config(share_url=None, share_pid=None) else: console.print(f"[bold red]Error:[/] Failed to stop share link: {e}") diff --git a/src/mcpm/router/app.py b/src/mcpm/router/app.py index 27862a8e..42830018 100644 --- a/src/mcpm/router/app.py +++ b/src/mcpm/router/app.py @@ -16,6 +16,7 @@ from mcpm.monitor.event import monitor from mcpm.router.router import MCPRouter from mcpm.router.transport import RouterSseTransport +from mcpm.utils.config import ConfigManager from mcpm.utils.platform import get_log_directory LOG_DIR = get_log_directory("mcpm") @@ -30,8 +31,12 @@ ) logger = logging.getLogger("mcpm.router.daemon") +config = ConfigManager().get_router_config() +api_key = config.get("api_key") +auth_enabled = config.get("auth_enabled", False) + router = MCPRouter(reload_server=True) -sse = RouterSseTransport("/messages/") +sse = RouterSseTransport("/messages/", api_key=api_key if auth_enabled else None) class NoOpsResponse(Response): diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index 63097cf6..bc07bf44 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -17,7 +17,7 @@ from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.routing import Mount, Route -from starlette.types import AppType, Lifespan +from starlette.types import Lifespan from mcpm.monitor.base import AccessEventType from mcpm.monitor.event import trace_event @@ -51,15 +51,10 @@ class MCPRouter: # Initialize with custom router configuration router_config = RouterConfig( - host="localhost", - port=8080, - share_address="custom.share.address:8080", - api_key="your-api-key" + api_key="your-api-key", + auth_enabled=True ) router = MCPRouter(router_config=router_config) - - # Create a global config from the router's configuration - router.create_global_config() ``` """ @@ -88,28 +83,11 @@ def __init__( self.watcher: Optional[ConfigWatcher] = None if reload_server: self.watcher = ConfigWatcher(self.profile_manager.profile_path) + if router_config is None: + config = ConfigManager().get_router_config() + router_config = RouterConfig(api_key=config.get("api_key"), auth_enabled=config.get("auth_enabled", False)) self.router_config = router_config if router_config is not None else RouterConfig() - def create_global_config(self) -> None: - """ - Create a global configuration from the router's configuration. - This is useful if you want to initialize the router with a config - but also want that config to be available globally. - """ - # Skip if router_config is None or there's no explicit api_key set - if self.router_config is None or self.router_config.api_key is None: - return - - config_manager = ConfigManager() - - # Save the API key to the global config - config_manager.save_share_config(api_key=self.router_config.api_key) - - # Save router configuration to the global config - config_manager.save_router_config( - self.router_config.host, self.router_config.port, self.router_config.share_address - ) - def get_unique_servers(self) -> list[ServerConfig]: profiles = self.profile_manager.list_profiles() name_to_server = {server.name: server for server_list in profiles.values() for server in server_list} @@ -178,87 +156,75 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # Collect server tools, prompts, and resources if response.capabilities.tools: tools = await client.session.list_tools() # type: ignore - # Extract ListToolsResult from ServerResult - tools_result = tools.root - if isinstance(tools_result, types.ListToolsResult): - for tool in tools_result.tools: - # To make sure tool name is unique across all servers - tool_name = tool.name - if tool_name in self.capabilities_to_server_id["tools"]: - if self.router_config.strict: - raise ValueError( - f"Tool {tool_name} already exists. Please use unique tool names across all servers." - ) - else: - # Auto resolve by adding server name prefix - tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}" - self.capabilities_to_server_id["tools"][tool_name] = server_id - self.tools_mapping[tool_name] = tool + for tool in tools.tools: + # To make sure tool name is unique across all servers + tool_name = tool.name + if tool_name in self.capabilities_to_server_id["tools"]: + if self.router_config.strict: + raise ValueError( + f"Tool {tool_name} already exists. Please use unique tool names across all servers." + ) + else: + # Auto resolve by adding server name prefix + tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}" + self.capabilities_to_server_id["tools"][tool_name] = server_id + self.tools_mapping[tool_name] = tool if response.capabilities.prompts: prompts = await client.session.list_prompts() # type: ignore - # Extract ListPromptsResult from ServerResult - prompts_result = prompts.root - if isinstance(prompts_result, types.ListPromptsResult): - for prompt in prompts_result.prompts: - # To make sure prompt name is unique across all servers - prompt_name = prompt.name - if prompt_name in self.capabilities_to_server_id["prompts"]: - if self.router_config.strict: - raise ValueError( - f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." - ) - else: - # Auto resolve by adding server name prefix - prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}" - self.prompts_mapping[prompt_name] = prompt - self.capabilities_to_server_id["prompts"][prompt_name] = server_id + for prompt in prompts.prompts: + # To make sure prompt name is unique across all servers + prompt_name = prompt.name + if prompt_name in self.capabilities_to_server_id["prompts"]: + if self.router_config.strict: + raise ValueError( + f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." + ) + else: + # Auto resolve by adding server name prefix + prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}" + self.prompts_mapping[prompt_name] = prompt + self.capabilities_to_server_id["prompts"][prompt_name] = server_id if response.capabilities.resources: resources = await client.session.list_resources() # type: ignore - # Extract ListResourcesResult from ServerResult - resources_result = resources.root - if isinstance(resources_result, types.ListResourcesResult): - for resource in resources_result.resources: - # To make sure resource URI is unique across all servers - resource_uri = resource.uri - if str(resource_uri) in self.capabilities_to_server_id["resources"]: - if self.router_config.strict: - raise ValueError( - f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." - ) - else: - # Auto resolve by adding server name prefix - host = resource_uri.host - resource_uri = AnyUrl.build( - host=f"{server_id}{RESOURCE_SPLITOR}{host}", - scheme=resource_uri.scheme, - path=resource_uri.path, - username=resource_uri.username, - password=resource_uri.password, - port=resource_uri.port, - query=resource_uri.query, - fragment=resource_uri.fragment, - ) + for resource in resources.resources: + # To make sure resource URI is unique across all servers + resource_uri = resource.uri + if str(resource_uri) in self.capabilities_to_server_id["resources"]: + if self.router_config.strict: + raise ValueError( + f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." + ) + else: + # Auto resolve by adding server name prefix + host = resource_uri.host + resource_uri = AnyUrl.build( + host=f"{server_id}{RESOURCE_SPLITOR}{host}", + scheme=resource_uri.scheme, + path=resource_uri.path, + username=resource_uri.username, + password=resource_uri.password, + port=resource_uri.port, + query=resource_uri.query, + fragment=resource_uri.fragment, + ) self.resources_mapping[str(resource_uri)] = resource self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id resources_templates = await client.session.list_resource_templates() # type: ignore - # Extract ListResourceTemplatesResult from ServerResult - templates_result = resources_templates.root - if isinstance(templates_result, types.ListResourceTemplatesResult): - for resource_template in templates_result.resourceTemplates: - # To make sure resource template URI is unique across all servers - resource_template_uri_template = resource_template.uriTemplate - if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: - if self.router_config.strict: - raise ValueError( - f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." - ) - else: - # Auto resolve by adding server name prefix - resource_template_uri_template = ( - f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}" - ) + for resource_template in resources_templates.resourceTemplates: + # To make sure resource template URI is unique across all servers + resource_template_uri_template = resource_template.uriTemplate + if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: + if self.router_config.strict: + raise ValueError( + f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." + ) + else: + # Auto resolve by adding server name prefix + resource_template_uri_template = ( + f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}" + ) self.resources_templates_mapping[resource_template_uri_template] = resource_template self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id @@ -544,7 +510,7 @@ async def _initialize_server_capabilities(self): async def get_sse_server_app( self, allow_origins: t.Optional[t.List[str]] = None, include_lifespan: bool = True - ) -> AppType: + ) -> Starlette: """ Get the SSE server app. @@ -558,7 +524,7 @@ async def get_sse_server_app( await self.initialize_router() # Pass the API key to the RouterSseTransport - api_key = None if self.router_config is None else self.router_config.api_key + api_key = None if not self.router_config.auth_enabled else self.router_config.api_key sse = RouterSseTransport("/messages/", api_key=api_key) async def handle_sse(request: Request) -> None: @@ -573,11 +539,11 @@ async def handle_sse(request: Request) -> None: self.aggregated_server.initialization_options, ) - lifespan_handler: t.Optional[Lifespan[AppType]] = None + lifespan_handler: t.Optional[Lifespan[Starlette]] = None if include_lifespan: @asynccontextmanager - async def lifespan(app: AppType): + async def lifespan(app: Starlette): yield await self.shutdown() diff --git a/src/mcpm/router/router_config.py b/src/mcpm/router/router_config.py index 92ae2228..93980f70 100644 --- a/src/mcpm/router/router_config.py +++ b/src/mcpm/router/router_config.py @@ -1,8 +1,6 @@ from typing import Optional -from pydantic import BaseModel - -from mcpm.utils.config import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHARE_ADDRESS +from pydantic import BaseModel, field_validator class RouterConfig(BaseModel): @@ -10,8 +8,13 @@ class RouterConfig(BaseModel): Router configuration model for MCPRouter """ - host: str = DEFAULT_HOST - port: int = DEFAULT_PORT - share_address: str = DEFAULT_SHARE_ADDRESS - api_key: Optional[str] = None strict: bool = False + api_key: Optional[str] = None + auth_enabled: bool = False + + @field_validator("api_key", mode="after") + def check_api_key(cls, v, info): + # info is ValidationInfo in pydantic v2; info.data is the dict of parsed values + if info.data.get("auth_enabled") and v is None: + raise ValueError("api_key must be provided when auth_enabled is True") + return v diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index ac3399c5..032e5f93 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -258,14 +258,11 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: share_config = config_manager.read_share_config() router_config = config_manager.get_router_config() host_name = urlsplit(host).hostname - share_host_name = urlsplit(share_config["url"]).hostname - if share_config["url"] and (host_name == share_host_name or host_name != router_config["host"]): - share_api_key = share_config["api_key"] - if api_key == share_api_key: - return True + if share_config["url"] and host_name != router_config["host"]: + if api_key != self.api_key: + return False except Exception as e: logger.error(f"Failed to validate API key: {e}") + return False - # If we reach here, the API key is invalid - logger.warning("Unauthorized API key") - return False + return True diff --git a/src/mcpm/utils/config.py b/src/mcpm/utils/config.py index 77a7c466..3492d6a0 100644 --- a/src/mcpm/utils/config.py +++ b/src/mcpm/utils/config.py @@ -127,7 +127,7 @@ def get_router_config(self): return router_config - def save_router_config(self, host, port, share_address): + def save_router_config(self, host, port, share_address, api_key: str | None = None, auth_enabled: bool = False): """save router configuration to config file""" router_config = self.get_config().get("router", {}) @@ -135,12 +135,14 @@ def save_router_config(self, host, port, share_address): router_config["host"] = host router_config["port"] = port router_config["share_address"] = share_address + router_config["api_key"] = api_key + router_config["auth_enabled"] = auth_enabled # save config return self.set_config("router", router_config) - def save_share_config(self, share_url: str | None = None, share_pid: int | None = None, api_key: str | None = None): - return self.set_config("share", {"url": share_url, "pid": share_pid, "api_key": api_key}) + def save_share_config(self, share_url: str | None = None, share_pid: int | None = None): + return self.set_config("share", {"url": share_url, "pid": share_pid}) def read_share_config(self) -> Dict[str, Any]: return self.get_config().get("share", {}) diff --git a/tests/test_router.py b/tests/test_router.py index dc7fa895..d4e80f16 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -2,17 +2,16 @@ Tests for the router module """ -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp import InitializeResult -from mcp.types import ListToolsResult, ServerCapabilities, ServerResult, Tool, ToolsCapability +from mcp.types import ListToolsResult, ServerCapabilities, Tool, ToolsCapability from mcpm.router.client_connection import ServerConnection from mcpm.router.router import MCPRouter from mcpm.router.router_config import RouterConfig from mcpm.schemas.server_config import SSEServerConfig -from mcpm.utils.config import TOOL_SPLITOR @pytest.fixture @@ -37,13 +36,15 @@ def mock_server_connection(): mock_session = AsyncMock() # Create a valid tool with proper inputSchema structure mock_tool = Tool(name="test-tool", description="A test tool", inputSchema={"type": "object", "properties": {}}) - # Create a ListToolsResult to be the root of ServerResult + # Create a ListToolsResult to be returned directly tools_result = ListToolsResult(tools=[mock_tool]) - # Create a ServerResult with ListToolsResult as its root - mock_list_tools_result = ServerResult(root=tools_result) - mock_session.list_tools = AsyncMock(return_value=mock_list_tools_result) - mock_conn.session = mock_session + mock_session.list_tools = AsyncMock(return_value=tools_result) + # If you have prompts/resources, mock them similarly: + mock_session.list_prompts = AsyncMock(return_value=MagicMock(prompts=[])) + mock_session.list_resources = AsyncMock(return_value=MagicMock(resources=[])) + mock_session.list_resource_templates = AsyncMock(return_value=MagicMock(resourceTemplates=[])) + mock_conn.session = mock_session return mock_conn @@ -58,9 +59,7 @@ async def test_router_init(): assert router.router_config.strict is False # Test with custom values - config = RouterConfig( - host="custom-host", port=9000, share_address="custom-share-address", api_key="test-api-key", strict=True - ) + config = RouterConfig(api_key="test-api-key", strict=True) router = MCPRouter( reload_server=True, router_config=config, @@ -72,27 +71,6 @@ async def test_router_init(): assert router.router_config.strict is True -def test_create_global_config(): - """Test creating a global config from router config""" - config = RouterConfig(host="custom-host", port=9000, share_address="custom-share-address", api_key="test-api-key") - - with patch("mcpm.router.router.ConfigManager") as mock_config_manager: - mock_instance = Mock() - mock_config_manager.return_value = mock_instance - - # Test without router_config - router = MCPRouter() - router.create_global_config() - mock_instance.save_share_config.assert_not_called() - mock_instance.save_router_config.assert_not_called() - - # Test with router_config - router = MCPRouter(router_config=config) - router.create_global_config() - mock_instance.save_share_config.assert_called_once_with(api_key="test-api-key") - mock_instance.save_router_config.assert_called_once_with("custom-host", 9000, "custom-share-address") - - @pytest.mark.asyncio async def test_add_server(mock_server_connection): """Test adding a server to the router""" @@ -141,137 +119,6 @@ async def test_add_server_unhealthy(): await router.add_server("unhealthy-server", server_config) -@pytest.mark.asyncio -async def test_add_server_duplicate_tool_strict(): - """Test adding a server with duplicate tool name in strict mode""" - router = MCPRouter(router_config=RouterConfig(strict=True)) - - # Mock get_active_servers to return all server IDs - def mock_get_active_servers(_profile): - return list(router.server_sessions.keys()) - - # Patch the _patch_handler_func method to use our mock - with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: - mock_patch_handler.return_value.get_active_servers = mock_get_active_servers - - server_config = SSEServerConfig(name="test-server", url="http://localhost:8080/sse") - - # Add first server with a tool - mock_conn1 = MagicMock(spec=ServerConnection) - mock_conn1.healthy.return_value = True - mock_conn1.request_for_shutdown = AsyncMock() - - # Create valid ServerCapabilities with ToolsCapability - tools_capability = ToolsCapability(listChanged=False) - capabilities = ServerCapabilities( - prompts=None, resources=None, tools=tools_capability, logging=None, experimental={} - ) - - mock_conn1.session_initialized_response = InitializeResult( - protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "test-server", "version": "1.0.0"} - ) - - mock_session1 = AsyncMock() - mock_tool = Tool( - name="duplicate-tool", description="A test tool", inputSchema={"type": "object", "properties": {}} - ) - # Create a ListToolsResult to be the root of ServerResult - tools_result = ListToolsResult(tools=[mock_tool]) - # Create a ServerResult with ListToolsResult as its root - mock_list_tools_result = ServerResult(root=tools_result) - mock_session1.list_tools = AsyncMock(return_value=mock_list_tools_result) - mock_conn1.session = mock_session1 - - # Add second server with same tool name - mock_conn2 = MagicMock(spec=ServerConnection) - mock_conn2.healthy.return_value = True - mock_conn2.request_for_shutdown = AsyncMock() - - mock_conn2.session_initialized_response = InitializeResult( - protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "second-server", "version": "1.0.0"} - ) - - mock_session2 = AsyncMock() - mock_session2.list_tools = AsyncMock(return_value=mock_list_tools_result) - mock_conn2.session = mock_session2 - - with patch("mcpm.router.router.ServerConnection", side_effect=[mock_conn1, mock_conn2]): - # Add first server should succeed - await router.add_server("test-server", server_config) - - # Add second server with duplicate tool should fail in strict mode - with pytest.raises(ValueError, match="Tool duplicate-tool already exists"): - await router.add_server("second-server", server_config) - - -@pytest.mark.asyncio -async def test_add_server_duplicate_tool_non_strict(): - """Test adding a server with duplicate tool name in non-strict mode""" - router = MCPRouter(router_config=RouterConfig(strict=False)) - - # Mock get_active_servers to return all server IDs - def mock_get_active_servers(_profile): - return list(router.server_sessions.keys()) - - # Patch the _patch_handler_func method to use our mock - with patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler: - mock_patch_handler.return_value.get_active_servers = mock_get_active_servers - - server_config = SSEServerConfig(name="test-server", url="http://localhost:8080/sse") - second_server_config = SSEServerConfig(name="second-server", url="http://localhost:8081/sse") - - # Add first server with a tool - mock_conn1 = MagicMock(spec=ServerConnection) - mock_conn1.healthy.return_value = True - mock_conn1.request_for_shutdown = AsyncMock() - - # Create valid ServerCapabilities with ToolsCapability - tools_capability = ToolsCapability(listChanged=False) - capabilities = ServerCapabilities( - prompts=None, resources=None, tools=tools_capability, logging=None, experimental={} - ) - - mock_conn1.session_initialized_response = InitializeResult( - protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "test-server", "version": "1.0.0"} - ) - - mock_session1 = AsyncMock() - mock_tool = Tool( - name="duplicate-tool", description="A test tool", inputSchema={"type": "object", "properties": {}} - ) - # Create a ListToolsResult to be the root of ServerResult - tools_result = ListToolsResult(tools=[mock_tool]) - # Create a ServerResult with ListToolsResult as its root - mock_list_tools_result = ServerResult(root=tools_result) - mock_session1.list_tools = AsyncMock(return_value=mock_list_tools_result) - mock_conn1.session = mock_session1 - - # Add second server with same tool name - mock_conn2 = MagicMock(spec=ServerConnection) - mock_conn2.healthy.return_value = True - mock_conn2.request_for_shutdown = AsyncMock() - - mock_conn2.session_initialized_response = InitializeResult( - protocolVersion="1.0", capabilities=capabilities, serverInfo={"name": "second-server", "version": "1.0.0"} - ) - - mock_session2 = AsyncMock() - mock_session2.list_tools = AsyncMock(return_value=mock_list_tools_result) - mock_conn2.session = mock_session2 - - with patch("mcpm.router.router.ServerConnection", side_effect=[mock_conn1, mock_conn2]): - # Add first server - await router.add_server("test-server", server_config) - assert "duplicate-tool" in router.tools_mapping - assert router.capabilities_to_server_id["tools"]["duplicate-tool"] == "test-server" - - # Add second server with duplicate tool - should prefix the tool name - await router.add_server("second-server", second_server_config) - prefixed_tool_name = f"second-server{TOOL_SPLITOR}duplicate-tool" - assert prefixed_tool_name in router.capabilities_to_server_id["tools"] - assert router.capabilities_to_server_id["tools"][prefixed_tool_name] == "second-server" - - @pytest.mark.asyncio async def test_remove_server(): """Test removing a server from the router""" @@ -429,24 +276,9 @@ async def test_router_sse_transport_with_api_key(): @pytest.mark.asyncio async def test_get_sse_server_app_with_api_key(): - """Test that the API key is passed to RouterSseTransport when creating the server app""" - router = MCPRouter(router_config=RouterConfig(api_key="test-api-key")) - - # Patch the RouterSseTransport constructor and get_active_servers method - with ( - patch("mcpm.router.router.RouterSseTransport") as mock_transport, - patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler, - ): - # Set up mocks for initialization - def mock_get_active_servers(_profile): - return list(router.server_sessions.keys()) - - mock_patch_handler.return_value.get_active_servers = mock_get_active_servers - - # Call the method + with patch("mcpm.router.router.RouterSseTransport") as mock_transport: + router = MCPRouter(router_config=RouterConfig(auth_enabled=True, api_key="test-api-key")) await router.get_sse_server_app() - - # Check that RouterSseTransport was created with the correct API key mock_transport.assert_called_once() call_kwargs = mock_transport.call_args[1] assert call_kwargs.get("api_key") == "test-api-key" @@ -454,24 +286,9 @@ def mock_get_active_servers(_profile): @pytest.mark.asyncio async def test_get_sse_server_app_without_api_key(): - """Test that None is passed to RouterSseTransport when no API key is provided""" - router = MCPRouter() # No API key or router_config - - # Patch the RouterSseTransport constructor and get_active_servers method - with ( - patch("mcpm.router.router.RouterSseTransport") as mock_transport, - patch.object(router, "_patch_handler_func", wraps=router._patch_handler_func) as mock_patch_handler, - ): - # Set up mocks for initialization - def mock_get_active_servers(_profile): - return list(router.server_sessions.keys()) - - mock_patch_handler.return_value.get_active_servers = mock_get_active_servers - - # Call the method + with patch("mcpm.router.router.RouterSseTransport") as mock_transport: + router = MCPRouter(router_config=RouterConfig(auth_enabled=False, api_key="custom-secret")) await router.get_sse_server_app() - - # Check that RouterSseTransport was created with api_key=None mock_transport.assert_called_once() call_kwargs = mock_transport.call_args[1] assert call_kwargs.get("api_key") is None From 78223505fbd03b8f47dc49f96993d78a9839e121 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 29 Apr 2025 10:17:48 +0800 Subject: [PATCH 12/14] fix: share config key error --- src/mcpm/commands/router.py | 8 ++++---- src/mcpm/router/transport.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/mcpm/commands/router.py b/src/mcpm/commands/router.py index ce5f3c3e..351c25d0 100644 --- a/src/mcpm/commands/router.py +++ b/src/mcpm/commands/router.py @@ -451,12 +451,12 @@ def try_clear_share(): console.print("[bold yellow]Clearing share config...[/]") config_manager = ConfigManager() share_config = config_manager.read_share_config() - if share_config["url"]: + if share_config.get("url"): try: console.print("[bold yellow]Disabling share link...[/]") config_manager.save_share_config(share_url=None, share_pid=None) console.print("[bold green]Share link disabled[/]") - if share_config["pid"]: + if share_config.get("pid"): os.kill(share_config["pid"], signal.SIGTERM) except OSError as e: if e.errno == 3: # "No such process" @@ -473,11 +473,11 @@ def stop_share(): # check if there is a share link already running config_manager = ConfigManager() share_config = config_manager.read_share_config() - if not share_config["url"]: + if not share_config.get("url"): console.print("[yellow]No share link is active.[/]") return - pid = share_config["pid"] + pid = share_config.get("pid") if not pid: console.print("[yellow]No share link is active.[/]") return diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 032e5f93..7e7693fc 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -258,7 +258,7 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: share_config = config_manager.read_share_config() router_config = config_manager.get_router_config() host_name = urlsplit(host).hostname - if share_config["url"] and host_name != router_config["host"]: + if share_config.get("url") and host_name != router_config["host"]: if api_key != self.api_key: return False except Exception as e: From aeb2aa3b426d7847192b4f9135a56b2e7bda2037 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 29 Apr 2025 11:17:03 +0800 Subject: [PATCH 13/14] test: fix test --- src/mcpm/router/transport.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 7e7693fc..f8a277f7 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -255,10 +255,9 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool: host = get_key_from_scope(scope, key_name="host") or "" if not host.startswith("http"): host = f"http://{host}" - share_config = config_manager.read_share_config() router_config = config_manager.get_router_config() host_name = urlsplit(host).hostname - if share_config.get("url") and host_name != router_config["host"]: + if host_name != router_config["host"]: if api_key != self.api_key: return False except Exception as e: From 52b7df1ced0a0f7b361eb9abba9403d882239552 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Wed, 30 Apr 2025 11:36:25 +0800 Subject: [PATCH 14/14] style: ruff --- src/mcpm/router/router.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index dd1124c1..908a167f 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -23,7 +23,6 @@ from mcpm.monitor.base import AccessEventType from mcpm.monitor.event import trace_event from mcpm.profile.profile_config import ProfileConfigManager -from mcpm.schemas.server_config import ServerConfig from mcpm.utils.config import ( PROMPT_SPLITOR, RESOURCE_SPLITOR,