Skip to content

Commit bf21d42

Browse files
xingyaowwopenhands-agentniechenJoJoJoJoJoJoJo
authored
feat(router): Support custom api key and auth enable/disable in router (#106)
* Add support for custom API key and router configuration in MCPRouter * Remove example scripts and update README * Add support for disabling API key validation when api_key is set to None * Remove test files and revert README changes * Revert docstring changes in router.py * Revert README changes * Fix linting issues with Ruff * Add tests for router and profile * refactor router config * add --auth/--no-auth router config * feat: enable custom api key and auth enable / disable * fix: share config key error * test: fix test * style: ruff --------- Co-authored-by: openhands <[email protected]> Co-authored-by: cnie <[email protected]> Co-authored-by: Jonathan Wang <[email protected]>
1 parent aa69c9c commit bf21d42

File tree

10 files changed

+678
-50
lines changed

10 files changed

+678
-50
lines changed

.cursor/rules/pytest.mdc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
description:
3+
globs: *.py
4+
alwaysApply: false
5+
---
6+
always run pytest at the end of a major change
7+
always run ruff lint at then end of any code changes

src/mcpm/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# Import version from internal module
66
# Import router module
77
from . import router
8+
from .router.router import MCPRouter
9+
from .router.router_config import RouterConfig
810
from .version import __version__
911

1012
# Define what symbols are exported from this package
11-
__all__ = ["__version__", "router"]
13+
__all__ = ["__version__", "router", "MCPRouter", "RouterConfig"]

src/mcpm/commands/router.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import socket
1010
import subprocess
1111
import sys
12-
import uuid
1312

1413
import click
1514
import psutil
@@ -41,6 +40,7 @@ def is_process_running(pid):
4140
except Exception:
4241
return False
4342

43+
4444
def is_port_listening(host, port) -> bool:
4545
"""
4646
Check if the specified (host, port) is being listened on.
@@ -133,9 +133,12 @@ def start_router(verbose):
133133
return
134134

135135
# get router config
136-
config = ConfigManager().get_router_config()
136+
config_manager = ConfigManager()
137+
config = config_manager.get_router_config()
137138
host = config["host"]
138139
port = config["port"]
140+
auth_enabled = config.get("auth_enabled", False)
141+
api_key = config.get("api_key")
139142

140143
# prepare uvicorn command
141144
uvicorn_cmd = [
@@ -185,9 +188,28 @@ def start_router(verbose):
185188
pid = process.pid
186189
write_pid_file(pid)
187190

191+
# Display router started information
188192
console.print(f"[bold green]MCPRouter started[/] at http://{host}:{port} (PID: {pid})")
189193
console.print(f"Log file: {log_file}")
190-
console.print("Use 'mcpm router off' to stop the router.")
194+
195+
# Display connection instructions
196+
console.print("\n[bold cyan]Connection Information:[/]")
197+
198+
api_key = api_key if auth_enabled else None
199+
200+
# Show URL with or without authentication based on API key availability
201+
if api_key:
202+
# Show authenticated URL
203+
console.print(f"SSE Server URL: [green]http://{host}:{port}/sse?s={api_key}[/]")
204+
console.print("\n[bold cyan]To use a specific profile with authentication:[/]")
205+
console.print(f"[green]http://{host}:{port}/sse?s={api_key}&profile=<profile_name>[/]")
206+
else:
207+
# Show URL without authentication
208+
console.print(f"SSE Server URL: [green]http://{host}:{port}/sse[/]")
209+
console.print("\n[bold cyan]To use a specific profile:[/]")
210+
console.print(f"[green]http://{host}:{port}/sse?profile=<profile_name>[/]")
211+
212+
console.print("\n[yellow]Use 'mcpm router off' to stop the router.[/]")
191213

192214
except Exception as e:
193215
console.print(f"[bold red]Error:[/] Failed to start MCPRouter: {e}")
@@ -197,17 +219,23 @@ def start_router(verbose):
197219
@click.option("-H", "--host", type=str, help="Host to bind the SSE server to")
198220
@click.option("-p", "--port", type=int, help="Port to bind the SSE server to")
199221
@click.option("-a", "--address", type=str, help="Remote address to share the router")
222+
@click.option(
223+
"--auth/--no-auth", default=True, is_flag=True, help="Enable/disable API key authentication (default: enabled)"
224+
)
225+
@click.option("-s", "--secret", type=str, help="Secret key for authentication")
200226
@click.help_option("-h", "--help")
201-
def set_router_config(host, port, address):
227+
def set_router_config(host, port, address, auth, secret: str | None = None):
202228
"""Set MCPRouter global configuration.
203229
204230
Example:
205231
mcpm router set -H localhost -p 8888
206232
mcpm router set --host 127.0.0.1 --port 9000
233+
mcpm router set --no-auth # disable authentication
234+
mcpm router set --auth # enable authentication
207235
"""
208-
if not host and not port and not address:
236+
if not host and not port and not address and auth is None:
209237
console.print(
210-
"[yellow]No changes were made. Please specify at least one option (--host, --port, or --address)[/]"
238+
"[yellow]No changes were made. Please specify at least one option (--host, --port, --address, --auth/--no-auth)[/]"
211239
)
212240
return
213241

@@ -219,9 +247,23 @@ def set_router_config(host, port, address):
219247
host = host or current_config["host"]
220248
port = port or current_config["port"]
221249
share_address = address or current_config["share_address"]
250+
api_key = secret
251+
252+
if auth:
253+
# Enable authentication
254+
if api_key is None:
255+
# Generate a new API key if authentication is enabled but no key exists
256+
api_key = secrets.token_urlsafe(32)
257+
console.print("[bold green]API key authentication enabled.[/] Generated new API key.")
258+
else:
259+
console.print("[bold green]API key authentication enabled.[/] Using provided API key.")
260+
else:
261+
# Disable authentication by clearing the API key
262+
api_key = None
263+
console.print("[bold yellow]API key authentication disabled.[/]")
222264

223-
# save config
224-
if config_manager.save_router_config(host, port, share_address):
265+
# save router config
266+
if config_manager.save_router_config(host, port, share_address, api_key=api_key, auth_enabled=auth):
225267
console.print(
226268
f"[bold green]Router configuration updated:[/] host={host}, port={port}, share_address={share_address}"
227269
)
@@ -329,7 +371,7 @@ def router_status():
329371
if share_config.get("pid"):
330372
if not is_process_running(share_config["pid"]):
331373
console.print("[yellow]Share link is not active, cleaning.[/]")
332-
ConfigManager().save_share_config(share_url=None, share_pid=None, api_key=None)
374+
ConfigManager().save_share_config(share_url=None, share_pid=None)
333375
console.print("[green]Share link cleaned[/]")
334376
else:
335377
console.print(
@@ -389,17 +431,17 @@ def share(address, profile, http):
389431
tunnel = Tunnel(remote_host, remote_port, config["host"], config["port"], secrets.token_urlsafe(32), http, None)
390432
share_url = tunnel.start_tunnel()
391433
share_pid = tunnel.proc.pid if tunnel.proc else None
392-
# generate random api key
393-
api_key = str(uuid.uuid4())
394-
console.print(f"[bold green]Generated secret for share link: {api_key}[/]")
434+
api_key = config.get("api_key") if config.get("auth_enabled") else None
395435
share_url = share_url + "/sse"
396436
# save share pid and link to config
397-
config_manager.save_share_config(share_url, share_pid, api_key)
437+
config_manager.save_share_config(share_url, share_pid)
398438
profile = profile or "<your_profile>"
399439

400440
# print share link
401441
console.print(f"[bold green]Router is sharing at {share_url}[/]")
402-
console.print(f"[green]Your profile can be accessed with the url {share_url}?s={api_key}&profile={profile}[/]\n")
442+
console.print(
443+
f"[green]Your profile can be accessed with the url {share_url}?profile={profile}{f'&s={api_key}' if api_key else ''}[/]\n"
444+
)
403445
console.print(
404446
"[bold yellow]Be careful about the share link, it will be exposed to the public. Make sure to share to trusted users only.[/]"
405447
)
@@ -409,17 +451,17 @@ def try_clear_share():
409451
console.print("[bold yellow]Clearing share config...[/]")
410452
config_manager = ConfigManager()
411453
share_config = config_manager.read_share_config()
412-
if share_config["url"]:
454+
if share_config.get("url"):
413455
try:
414456
console.print("[bold yellow]Disabling share link...[/]")
415-
config_manager.save_share_config(share_url=None, share_pid=None, api_key=None)
457+
config_manager.save_share_config(share_url=None, share_pid=None)
416458
console.print("[bold green]Share link disabled[/]")
417-
if share_config["pid"]:
459+
if share_config.get("pid"):
418460
os.kill(share_config["pid"], signal.SIGTERM)
419461
except OSError as e:
420462
if e.errno == 3: # "No such process"
421463
console.print("[yellow]Share process does not exist, cleaning up share config...[/]")
422-
config_manager.save_share_config(share_url=None, share_pid=None, api_key=None)
464+
config_manager.save_share_config(share_url=None, share_pid=None)
423465
else:
424466
console.print(f"[bold red]Error:[/] Failed to stop share link: {e}")
425467

@@ -431,11 +473,11 @@ def stop_share():
431473
# check if there is a share link already running
432474
config_manager = ConfigManager()
433475
share_config = config_manager.read_share_config()
434-
if not share_config["url"]:
476+
if not share_config.get("url"):
435477
console.print("[yellow]No share link is active.[/]")
436478
return
437479

438-
pid = share_config["pid"]
480+
pid = share_config.get("pid")
439481
if not pid:
440482
console.print("[yellow]No share link is active.[/]")
441483
return

src/mcpm/router/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mcpm.monitor.event import monitor
1717
from mcpm.router.router import MCPRouter
1818
from mcpm.router.transport import RouterSseTransport
19+
from mcpm.utils.config import ConfigManager
1920
from mcpm.utils.platform import get_log_directory
2021

2122
LOG_DIR = get_log_directory("mcpm")
@@ -30,8 +31,12 @@
3031
)
3132
logger = logging.getLogger("mcpm.router.daemon")
3233

34+
config = ConfigManager().get_router_config()
35+
api_key = config.get("api_key")
36+
auth_enabled = config.get("auth_enabled", False)
37+
3338
router = MCPRouter(reload_server=True)
34-
sse = RouterSseTransport("/messages/")
39+
sse = RouterSseTransport("/messages/", api_key=api_key if auth_enabled else None)
3540

3641

3742
class NoOpsResponse(Response):

src/mcpm/router/router.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,23 @@
1717
from starlette.middleware.cors import CORSMiddleware
1818
from starlette.requests import Request
1919
from starlette.routing import Mount, Route
20-
from starlette.types import AppType, Lifespan
20+
from starlette.types import Lifespan
2121

2222
from mcpm.core.schema import ServerConfig
2323
from mcpm.monitor.base import AccessEventType
2424
from mcpm.monitor.event import trace_event
2525
from mcpm.profile.profile_config import ProfileConfigManager
26-
from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, RESOURCE_TEMPLATE_SPLITOR, TOOL_SPLITOR
26+
from mcpm.utils.config import (
27+
PROMPT_SPLITOR,
28+
RESOURCE_SPLITOR,
29+
RESOURCE_TEMPLATE_SPLITOR,
30+
TOOL_SPLITOR,
31+
ConfigManager,
32+
)
2733
from mcpm.utils.errlog_manager import ServerErrorLogManager
2834

2935
from .client_connection import ServerConnection
36+
from .router_config import RouterConfig
3037
from .transport import RouterSseTransport
3138
from .watcher import ConfigWatcher
3239

@@ -37,16 +44,33 @@ class MCPRouter:
3744
"""
3845
A router that aggregates multiple MCP servers (SSE/STDIO) and
3946
exposes them as a single SSE server.
47+
48+
Example:
49+
```python
50+
# Initialize with a custom API key
51+
router = MCPRouter(router_config=RouterConfig(api_key="your-api-key"))
52+
53+
# Initialize with custom router configuration
54+
router_config = RouterConfig(
55+
api_key="your-api-key",
56+
auth_enabled=True
57+
)
58+
router = MCPRouter(router_config=router_config)
59+
```
4060
"""
4161

42-
def __init__(self, reload_server: bool = False, profile_path: str | None = None, strict: bool = False) -> None:
62+
def __init__(
63+
self,
64+
reload_server: bool = False,
65+
profile_path: str | None = None,
66+
router_config: RouterConfig | None = None,
67+
) -> None:
4368
"""
4469
Initialize the router.
4570
4671
:param reload_server: Whether to reload the server when the config changes
4772
:param profile_path: Path to the profile file
48-
:param strict: Whether to use strict mode for duplicated tool name.
49-
If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix
73+
:param router_config: Optional router configuration to use instead of the global config
5074
"""
5175
self.server_sessions: t.Dict[str, ServerConnection] = {}
5276
self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict)
@@ -60,7 +84,10 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None,
6084
self.watcher: Optional[ConfigWatcher] = None
6185
if reload_server:
6286
self.watcher = ConfigWatcher(self.profile_manager.profile_path)
63-
self.strict: bool = strict
87+
if router_config is None:
88+
config = ConfigManager().get_router_config()
89+
router_config = RouterConfig(api_key=config.get("api_key"), auth_enabled=config.get("auth_enabled", False))
90+
self.router_config = router_config
6491
self.error_log_manager = ServerErrorLogManager()
6592

6693
def get_unique_servers(self) -> list[ServerConfig]:
@@ -137,7 +164,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
137164
# To make sure tool name is unique across all servers
138165
tool_name = tool.name
139166
if tool_name in self.capabilities_to_server_id["tools"]:
140-
if self.strict:
167+
if self.router_config.strict:
141168
raise ValueError(
142169
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
143170
)
@@ -153,7 +180,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
153180
# To make sure prompt name is unique across all servers
154181
prompt_name = prompt.name
155182
if prompt_name in self.capabilities_to_server_id["prompts"]:
156-
if self.strict:
183+
if self.router_config.strict:
157184
raise ValueError(
158185
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
159186
)
@@ -169,7 +196,7 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
169196
# To make sure resource URI is unique across all servers
170197
resource_uri = resource.uri
171198
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
172-
if self.strict:
199+
if self.router_config.strict:
173200
raise ValueError(
174201
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
175202
)
@@ -186,14 +213,14 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
186213
query=resource_uri.query,
187214
fragment=resource_uri.fragment,
188215
)
189-
self.resources_mapping[str(resource_uri)] = resource
190-
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
216+
self.resources_mapping[str(resource_uri)] = resource
217+
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
191218
resources_templates = await client.session.list_resource_templates() # type: ignore
192219
for resource_template in resources_templates.resourceTemplates:
193220
# To make sure resource template URI is unique across all servers
194221
resource_template_uri_template = resource_template.uriTemplate
195222
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
196-
if self.strict:
223+
if self.router_config.strict:
197224
raise ValueError(
198225
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
199226
)
@@ -202,8 +229,8 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
202229
resource_template_uri_template = (
203230
f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}"
204231
)
205-
self.resources_templates_mapping[resource_template_uri_template] = resource_template
206-
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id
232+
self.resources_templates_mapping[resource_template_uri_template] = resource_template
233+
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id
207234

208235
async def remove_server(self, server_id: str) -> None:
209236
"""
@@ -488,7 +515,7 @@ async def _initialize_server_capabilities(self):
488515

489516
async def get_sse_server_app(
490517
self, allow_origins: t.Optional[t.List[str]] = None, include_lifespan: bool = True
491-
) -> AppType:
518+
) -> Starlette:
492519
"""
493520
Get the SSE server app.
494521
@@ -501,7 +528,9 @@ async def get_sse_server_app(
501528
"""
502529
await self.initialize_router()
503530

504-
sse = RouterSseTransport("/messages/")
531+
# Pass the API key to the RouterSseTransport
532+
api_key = None if not self.router_config.auth_enabled else self.router_config.api_key
533+
sse = RouterSseTransport("/messages/", api_key=api_key)
505534

506535
async def handle_sse(request: Request) -> None:
507536
async with sse.connect_sse(
@@ -515,11 +544,11 @@ async def handle_sse(request: Request) -> None:
515544
self.aggregated_server.initialization_options,
516545
)
517546

518-
lifespan_handler: t.Optional[Lifespan[AppType]] = None
547+
lifespan_handler: t.Optional[Lifespan[Starlette]] = None
519548
if include_lifespan:
520549

521550
@asynccontextmanager
522-
async def lifespan(app: AppType):
551+
async def lifespan(app: Starlette):
523552
yield
524553
await self.shutdown()
525554

0 commit comments

Comments
 (0)