Skip to content

Commit 3b1893c

Browse files
committed
Add tests for router and profile
1 parent 9e16baa commit 3b1893c

File tree

4 files changed

+780
-69
lines changed

4 files changed

+780
-69
lines changed

src/mcpm/router/router.py

Lines changed: 77 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
profile_path: str | None = None,
7171
strict: bool = False,
7272
api_key: str | None = None,
73-
router_config: dict | None = None
73+
router_config: dict | None = None,
7474
) -> None:
7575
"""
7676
Initialize the router.
@@ -184,77 +184,89 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
184184
# Collect server tools, prompts, and resources
185185
if response.capabilities.tools:
186186
tools = await client.session.list_tools() # type: ignore
187-
for tool in tools.tools:
188-
# To make sure tool name is unique across all servers
189-
tool_name = tool.name
190-
if tool_name in self.capabilities_to_server_id["tools"]:
191-
if self.strict:
192-
raise ValueError(
193-
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
194-
)
195-
else:
196-
# Auto resolve by adding server name prefix
197-
tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}"
198-
self.capabilities_to_server_id["tools"][tool_name] = server_id
199-
self.tools_mapping[tool_name] = tool
187+
# Extract ListToolsResult from ServerResult
188+
tools_result = tools.root
189+
if isinstance(tools_result, types.ListToolsResult):
190+
for tool in tools_result.tools:
191+
# To make sure tool name is unique across all servers
192+
tool_name = tool.name
193+
if tool_name in self.capabilities_to_server_id["tools"]:
194+
if self.strict:
195+
raise ValueError(
196+
f"Tool {tool_name} already exists. Please use unique tool names across all servers."
197+
)
198+
else:
199+
# Auto resolve by adding server name prefix
200+
tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}"
201+
self.capabilities_to_server_id["tools"][tool_name] = server_id
202+
self.tools_mapping[tool_name] = tool
200203

201204
if response.capabilities.prompts:
202205
prompts = await client.session.list_prompts() # type: ignore
203-
for prompt in prompts.prompts:
204-
# To make sure prompt name is unique across all servers
205-
prompt_name = prompt.name
206-
if prompt_name in self.capabilities_to_server_id["prompts"]:
207-
if self.strict:
208-
raise ValueError(
209-
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
210-
)
211-
else:
212-
# Auto resolve by adding server name prefix
213-
prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}"
214-
self.prompts_mapping[prompt_name] = prompt
215-
self.capabilities_to_server_id["prompts"][prompt_name] = server_id
206+
# Extract ListPromptsResult from ServerResult
207+
prompts_result = prompts.root
208+
if isinstance(prompts_result, types.ListPromptsResult):
209+
for prompt in prompts_result.prompts:
210+
# To make sure prompt name is unique across all servers
211+
prompt_name = prompt.name
212+
if prompt_name in self.capabilities_to_server_id["prompts"]:
213+
if self.strict:
214+
raise ValueError(
215+
f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers."
216+
)
217+
else:
218+
# Auto resolve by adding server name prefix
219+
prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}"
220+
self.prompts_mapping[prompt_name] = prompt
221+
self.capabilities_to_server_id["prompts"][prompt_name] = server_id
216222

217223
if response.capabilities.resources:
218224
resources = await client.session.list_resources() # type: ignore
219-
for resource in resources.resources:
220-
# To make sure resource URI is unique across all servers
221-
resource_uri = resource.uri
222-
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
223-
if self.strict:
224-
raise ValueError(
225-
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
226-
)
227-
else:
228-
# Auto resolve by adding server name prefix
229-
host = resource_uri.host
230-
resource_uri = AnyUrl.build(
231-
host=f"{server_id}{RESOURCE_SPLITOR}{host}",
232-
scheme=resource_uri.scheme,
233-
path=resource_uri.path,
234-
username=resource_uri.username,
235-
password=resource_uri.password,
236-
port=resource_uri.port,
237-
query=resource_uri.query,
238-
fragment=resource_uri.fragment,
239-
)
240-
self.resources_mapping[str(resource_uri)] = resource
241-
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
225+
# Extract ListResourcesResult from ServerResult
226+
resources_result = resources.root
227+
if isinstance(resources_result, types.ListResourcesResult):
228+
for resource in resources_result.resources:
229+
# To make sure resource URI is unique across all servers
230+
resource_uri = resource.uri
231+
if str(resource_uri) in self.capabilities_to_server_id["resources"]:
232+
if self.strict:
233+
raise ValueError(
234+
f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers."
235+
)
236+
else:
237+
# Auto resolve by adding server name prefix
238+
host = resource_uri.host
239+
resource_uri = AnyUrl.build(
240+
host=f"{server_id}{RESOURCE_SPLITOR}{host}",
241+
scheme=resource_uri.scheme,
242+
path=resource_uri.path,
243+
username=resource_uri.username,
244+
password=resource_uri.password,
245+
port=resource_uri.port,
246+
query=resource_uri.query,
247+
fragment=resource_uri.fragment,
248+
)
249+
self.resources_mapping[str(resource_uri)] = resource
250+
self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id
242251
resources_templates = await client.session.list_resource_templates() # type: ignore
243-
for resource_template in resources_templates.resourceTemplates:
244-
# To make sure resource template URI is unique across all servers
245-
resource_template_uri_template = resource_template.uriTemplate
246-
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
247-
if self.strict:
248-
raise ValueError(
249-
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
250-
)
251-
else:
252-
# Auto resolve by adding server name prefix
253-
resource_template_uri_template = (
254-
f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}"
255-
)
256-
self.resources_templates_mapping[resource_template_uri_template] = resource_template
257-
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id
252+
# Extract ListResourceTemplatesResult from ServerResult
253+
templates_result = resources_templates.root
254+
if isinstance(templates_result, types.ListResourceTemplatesResult):
255+
for resource_template in templates_result.resourceTemplates:
256+
# To make sure resource template URI is unique across all servers
257+
resource_template_uri_template = resource_template.uriTemplate
258+
if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]:
259+
if self.strict:
260+
raise ValueError(
261+
f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers."
262+
)
263+
else:
264+
# Auto resolve by adding server name prefix
265+
resource_template_uri_template = (
266+
f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}"
267+
)
268+
self.resources_templates_mapping[resource_template_uri_template] = resource_template
269+
self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id
258270

259271
async def remove_server(self, server_id: str) -> None:
260272
"""

src/mcpm/router/transport.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,15 @@ def _validate_api_key(self, scope: Scope, api_key: str | None) -> bool:
243243
if self.api_key is None:
244244
logger.debug("API key validation disabled")
245245
return True
246-
247-
# If we have a directly provided API key and it matches the request's API key, return True
248-
if self.api_key is not None and api_key == self.api_key:
246+
247+
# If we have a directly provided API key, verify it matches
248+
if self.api_key is not None:
249+
# If API key doesn't match, return False
250+
if api_key != self.api_key:
251+
logger.warning("Unauthorized API key")
252+
return False
249253
return True
250-
254+
251255
# Otherwise, fall back to the original validation logic
252256
try:
253257
config_manager = ConfigManager()

0 commit comments

Comments
 (0)