Skip to content

Commit a9c942a

Browse files
committed
refactor router config
1 parent 96b4d00 commit a9c942a

File tree

6 files changed

+84
-66
lines changed

6 files changed

+84
-66
lines changed

.cursor/rules/pytest.mdc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
description:
3+
globs: *.py
4+
alwaysApply: false
5+
---
6+
always run pytest at the end of a major change

src/mcpm/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# Import version from internal module
66
# Import router module
77
from . import router
8+
from .router.router import MCPRouter
9+
from .router.router_config import RouterConfig
810
from .version import __version__
911

1012
# Define what symbols are exported from this package
11-
__all__ = ["__version__", "router"]
13+
__all__ = ["__version__", "router", "MCPRouter", "RouterConfig"]

src/mcpm/router/router.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
from mcpm.profile.profile_config import ProfileConfigManager
2525
from mcpm.schemas.server_config import ServerConfig
2626
from mcpm.utils.config import (
27-
DEFAULT_HOST,
28-
DEFAULT_PORT,
29-
DEFAULT_SHARE_ADDRESS,
3027
PROMPT_SPLITOR,
3128
RESOURCE_SPLITOR,
3229
RESOURCE_TEMPLATE_SPLITOR,
@@ -35,6 +32,7 @@
3532
)
3633

3734
from .client_connection import ServerConnection
35+
from .router_config import RouterConfig
3836
from .transport import RouterSseTransport
3937
from .watcher import ConfigWatcher
4038

@@ -49,15 +47,16 @@ class MCPRouter:
4947
Example:
5048
```python
5149
# Initialize with a custom API key
52-
router = MCPRouter(api_key="your-api-key")
50+
router = MCPRouter(router_config=RouterConfig(api_key="your-api-key"))
5351
5452
# Initialize with custom router configuration
55-
router_config = {
56-
"host": "localhost",
57-
"port": 8080,
58-
"share_address": "custom.share.address:8080"
59-
}
60-
router = MCPRouter(api_key="your-api-key", router_config=router_config)
53+
router_config = RouterConfig(
54+
host="localhost",
55+
port=8080,
56+
share_address="custom.share.address:8080",
57+
api_key="your-api-key"
58+
)
59+
router = MCPRouter(router_config=router_config)
6160
6261
# Create a global config from the router's configuration
6362
router.create_global_config()
@@ -68,18 +67,13 @@ def __init__(
6867
self,
6968
reload_server: bool = False,
7069
profile_path: str | None = None,
71-
strict: bool = False,
72-
api_key: str | None = None,
73-
router_config: dict | None = None,
70+
router_config: RouterConfig | None = None,
7471
) -> None:
7572
"""
7673
Initialize the router.
7774
7875
:param reload_server: Whether to reload the server when the config changes
7976
:param profile_path: Path to the profile file
80-
:param strict: Whether to use strict mode for duplicated tool name.
81-
If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix
82-
:param api_key: Optional API key to use for authentication.
8377
:param router_config: Optional router configuration to use instead of the global config
8478
"""
8579
self.server_sessions: t.Dict[str, ServerConnection] = {}
@@ -94,27 +88,27 @@ def __init__(
9488
self.watcher: Optional[ConfigWatcher] = None
9589
if reload_server:
9690
self.watcher = ConfigWatcher(self.profile_manager.profile_path)
97-
self.strict: bool = strict
98-
self.api_key = api_key
99-
self.router_config = router_config
91+
self.router_config = router_config if router_config is not None else RouterConfig()
10092

10193
def create_global_config(self) -> None:
10294
"""
10395
Create a global configuration from the router's configuration.
10496
This is useful if you want to initialize the router with a config
10597
but also want that config to be available globally.
10698
"""
107-
if self.api_key is not None:
108-
config_manager = ConfigManager()
109-
# Save the API key to the global config
110-
config_manager.save_share_config(api_key=self.api_key)
111-
112-
# If router_config is provided, save it to the global config
113-
if self.router_config is not None:
114-
host = self.router_config.get("host", DEFAULT_HOST)
115-
port = self.router_config.get("port", DEFAULT_PORT)
116-
share_address = self.router_config.get("share_address", DEFAULT_SHARE_ADDRESS)
117-
config_manager.save_router_config(host, port, share_address)
99+
# Skip if router_config is None or there's no explicit api_key set
100+
if self.router_config is None or self.router_config.api_key is None:
101+
return
102+
103+
config_manager = ConfigManager()
104+
105+
# Save the API key to the global config
106+
config_manager.save_share_config(api_key=self.router_config.api_key)
107+
108+
# Save router configuration to the global config
109+
config_manager.save_router_config(
110+
self.router_config.host, self.router_config.port, self.router_config.share_address
111+
)
118112

119113
def get_unique_servers(self) -> list[ServerConfig]:
120114
profiles = self.profile_manager.list_profiles()
@@ -191,7 +185,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
191185
# To make sure tool name is unique across all servers
192186
tool_name = tool.name
193187
if tool_name in self.capabilities_to_server_id["tools"]:
194-
if self.strict:
188+
if self.router_config.strict:
195189
raise ValueError(
196190
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
197191
)
@@ -210,7 +204,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
210204
# To make sure prompt name is unique across all servers
211205
prompt_name = prompt.name
212206
if prompt_name in self.capabilities_to_server_id["prompts"]:
213-
if self.strict:
207+
if self.router_config.strict:
214208
raise ValueError(
215209
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
216210
)
@@ -229,7 +223,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
229223
# To make sure resource URI is unique across all servers
230224
resource_uri = resource.uri
231225
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
232-
if self.strict:
226+
if self.router_config.strict:
233227
raise ValueError(
234228
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
235229
)
@@ -256,7 +250,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
256250
# To make sure resource template URI is unique across all servers
257251
resource_template_uri_template = resource_template.uriTemplate
258252
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
259-
if self.strict:
253+
if self.router_config.strict:
260254
raise ValueError(
261255
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
262256
)
@@ -564,7 +558,8 @@ async def get_sse_server_app(
564558
await self.initialize_router()
565559

566560
# Pass the API key to the RouterSseTransport
567-
sse = RouterSseTransport("/messages/", api_key=self.api_key)
561+
api_key = None if self.router_config is None else self.router_config.api_key
562+
sse = RouterSseTransport("/messages/", api_key=api_key)
568563

569564
async def handle_sse(request: Request) -> None:
570565
async with sse.connect_sse(

src/mcpm/router/router_config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel
4+
5+
from mcpm.utils.config import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHARE_ADDRESS
6+
7+
8+
class RouterConfig(BaseModel):
9+
"""
10+
Router configuration model for MCPRouter
11+
"""
12+
13+
host: str = DEFAULT_HOST
14+
port: int = DEFAULT_PORT
15+
share_address: str = DEFAULT_SHARE_ADDRESS
16+
api_key: Optional[str] = None
17+
strict: bool = False

src/mcpm/router/transport.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,11 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool:
245245
return True
246246

247247
# If we have a directly provided API key, verify it matches
248-
if self.api_key is not None:
249-
# If API key doesn't match, return False
250-
if api_key != self.api_key:
251-
logger.warning("Unauthorized API key")
252-
return False
248+
if api_key == self.api_key:
253249
return True
254250

255-
# Otherwise, fall back to the original validation logic
251+
# At this point, self.api_key is not None but doesn't match the provided api_key
252+
# Let's check if this is a share URL that needs special validation
256253
try:
257254
config_manager = ConfigManager()
258255
host = get_key_from_scope(scope, key_name="host") or ""
@@ -264,10 +261,11 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool:
264261
share_host_name = urlsplit(share_config["url"]).hostname
265262
if share_config["url"] and (host_name == share_host_name or host_name != router_config["host"]):
266263
share_api_key = share_config["api_key"]
267-
if api_key != share_api_key:
268-
logger.warning("Unauthorized API key")
269-
return False
264+
if api_key == share_api_key:
265+
return True
270266
except Exception as e:
271267
logger.error(f"Failed to validate API key: {e}")
272-
return False
273-
return True
268+
269+
# If we reach here, the API key is invalid
270+
logger.warning("Unauthorized API key")
271+
return False

tests/test_router.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from mcpm.router.client_connection import ServerConnection
1212
from mcpm.router.router import MCPRouter
13+
from mcpm.router.router_config import RouterConfig
1314
from mcpm.schemas.server_config import SSEServerConfig
1415
from mcpm.utils.config import TOOL_SPLITOR
1516

@@ -53,41 +54,40 @@ async def test_router_init():
5354
router = MCPRouter()
5455
assert router.profile_manager is not None
5556
assert router.watcher is None
56-
assert router.strict is False
57-
assert router.api_key is None
58-
assert router.router_config is None
57+
assert router.router_config is not None
58+
assert router.router_config.strict is False
5959

6060
# Test with custom values
61-
router_config = {"host": "custom-host", "port": 9000}
61+
config = RouterConfig(
62+
host="custom-host", port=9000, share_address="custom-share-address", api_key="test-api-key", strict=True
63+
)
6264
router = MCPRouter(
6365
reload_server=True,
64-
strict=True,
65-
api_key="test-api-key",
66-
router_config=router_config,
66+
router_config=config,
6767
)
6868

6969
assert router.watcher is not None
70-
assert router.strict is True
71-
assert router.api_key == "test-api-key"
72-
assert router.router_config == router_config
70+
assert router.router_config == config
71+
assert router.router_config.api_key == "test-api-key"
72+
assert router.router_config.strict is True
7373

7474

7575
def test_create_global_config():
7676
"""Test creating a global config from router config"""
77-
router_config = {"host": "custom-host", "port": 9000, "share_address": "custom-share-address"}
77+
config = RouterConfig(host="custom-host", port=9000, share_address="custom-share-address", api_key="test-api-key")
7878

7979
with patch("mcpm.router.router.ConfigManager") as mock_config_manager:
8080
mock_instance = Mock()
8181
mock_config_manager.return_value = mock_instance
8282

83-
# Test without API key
84-
router = MCPRouter(router_config=router_config)
83+
# Test without router_config
84+
router = MCPRouter()
8585
router.create_global_config()
8686
mock_instance.save_share_config.assert_not_called()
8787
mock_instance.save_router_config.assert_not_called()
8888

89-
# Test with API key
90-
router = MCPRouter(api_key="test-api-key", router_config=router_config)
89+
# Test with router_config
90+
router = MCPRouter(router_config=config)
9191
router.create_global_config()
9292
mock_instance.save_share_config.assert_called_once_with(api_key="test-api-key")
9393
mock_instance.save_router_config.assert_called_once_with("custom-host", 9000, "custom-share-address")
@@ -144,7 +144,7 @@ async def test_add_server_unhealthy():
144144
@pytest.mark.asyncio
145145
async def test_add_server_duplicate_tool_strict():
146146
"""Test adding a server with duplicate tool name in strict mode"""
147-
router = MCPRouter(strict=True)
147+
router = MCPRouter(router_config=RouterConfig(strict=True))
148148

149149
# Mock get_active_servers to return all server IDs
150150
def mock_get_active_servers(_profile):
@@ -207,7 +207,7 @@ def mock_get_active_servers(_profile):
207207
@pytest.mark.asyncio
208208
async def test_add_server_duplicate_tool_non_strict():
209209
"""Test adding a server with duplicate tool name in non-strict mode"""
210-
router = MCPRouter(strict=False)
210+
router = MCPRouter(router_config=RouterConfig(strict=False))
211211

212212
# Mock get_active_servers to return all server IDs
213213
def mock_get_active_servers(_profile):
@@ -430,7 +430,7 @@ async def test_router_sse_transport_with_api_key():
430430
@pytest.mark.asyncio
431431
async def test_get_sse_server_app_with_api_key():
432432
"""Test that the API key is passed to RouterSseTransport when creating the server app"""
433-
router = MCPRouter(api_key="test-api-key")
433+
router = MCPRouter(router_config=RouterConfig(api_key="test-api-key"))
434434

435435
# Patch the RouterSseTransport constructor and get_active_servers method
436436
with (
@@ -455,7 +455,7 @@ def mock_get_active_servers(_profile):
455455
@pytest.mark.asyncio
456456
async def test_get_sse_server_app_without_api_key():
457457
"""Test that None is passed to RouterSseTransport when no API key is provided"""
458-
router = MCPRouter() # No API key
458+
router = MCPRouter() # No API key or router_config
459459

460460
# Patch the RouterSseTransport constructor and get_active_servers method
461461
with (

0 commit comments

Comments
 (0)