diff --git a/.github/workflows/semantic-check.yml b/.github/workflows/semantic-check.yml index 859ad45f..4ad4bece 100644 --- a/.github/workflows/semantic-check.yml +++ b/.github/workflows/semantic-check.yml @@ -1,23 +1,23 @@ name: Semantic Version Check on: - pull_request: + pull_request_target: types: [opened, synchronize, reopened] -permissions: - contents: read - pull-requests: write - issues: write - jobs: semver-check: name: Validate Semantic Version runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + issues: write steps: - name: Checkout uses: actions/checkout@v4 with: + ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 persist-credentials: false @@ -41,38 +41,24 @@ jobs: id: semantic with: dry_run: true - ci: false + ci: true extra_plugins: | @semantic-release/commit-analyzer @semantic-release/release-notes-generator - branches: | - [ - "${GITHUB_HEAD_REF}" - ] env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Comment PR - if: always() - uses: actions/github-script@v7 + uses: thollander/actions-comment-pull-request@v2 with: - script: | - let comment = '## Semantic Version Check\n\n'; + message: | + ## Semantic Version Check - if ('${{ steps.semantic.outputs.new_release_version }}') { - comment += `✅ Valid semantic version changes detected!\n\n`; - comment += `Next version will be: **${{ steps.semantic.outputs.new_release_version }}**\n`; - } else { - comment += `⚠️ No semantic version changes detected.\n\n`; - comment += 'Please ensure your commits follow the [Conventional Commits](https://www.conventionalcommits.org/) format:\n\n'; - comment += '- `feat: new feature` (triggers MINOR version bump)\n'; - comment += '- `fix: bug fix` (triggers PATCH version bump)\n'; - comment += '- `BREAKING CHANGE: description` (triggers MAJOR version bump)\n'; - } + ${{ steps.semantic.outputs.new_release_version && '✅ Valid semantic version changes detected!' || '⚠️ No semantic version changes detected.' }} - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: comment - }); \ No newline at end of file + ${{ steps.semantic.outputs.new_release_version && format('Next version will be: **{0}**', steps.semantic.outputs.new_release_version) || 'Please ensure your commits follow the [Conventional Commits](https://www.conventionalcommits.org/) format: + + - `feat: new feature` (triggers MINOR version bump) + - `fix: bug fix` (triggers PATCH version bump) + - `BREAKING CHANGE: description` (triggers MAJOR version bump)' }} + comment_tag: semantic-version-check \ No newline at end of file diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py index b6c7b947..8298ef45 100644 --- a/src/mcpm/router/router.py +++ b/src/mcpm/router/router.py @@ -6,7 +6,7 @@ import typing as t from collections import defaultdict from contextlib import asynccontextmanager -from typing import Optional +from typing import Literal, Optional import uvicorn from mcp import server, types @@ -17,7 +17,7 @@ from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.routing import Mount, Route -from starlette.types import AppType +from starlette.types import AppType, Lifespan from mcpm.monitor.base import AccessEventType from mcpm.monitor.event import trace_event @@ -38,19 +38,28 @@ class MCPRouter: exposes them as a single SSE server. 
""" - def __init__(self, reload_server: bool = False) -> None: - """Initialize the router.""" + def __init__(self, reload_server: bool = False, profile_path: str | None = None, strict: bool = False) -> None: + """ + Initialize the router. + + :param reload_server: Whether to reload the server when the config changes + :param profile_path: Path to the profile file + :param strict: Whether to use strict mode for duplicated tool name. + If True, raise error when duplicated tool name is found else auto resolve by adding server name prefix + """ self.server_sessions: t.Dict[str, ServerConnection] = {} self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict) - self.tools_mapping: t.Dict[str, t.Dict[str, t.Any]] = {} - self.prompts_mapping: t.Dict[str, t.Dict[str, t.Any]] = {} - self.resources_mapping: t.Dict[str, t.Dict[str, t.Any]] = {} - self.resources_templates_mapping: t.Dict[str, t.Dict[str, t.Any]] = {} + self.capabilities_to_server_id: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict) + self.tools_mapping: t.Dict[str, types.Tool] = {} + self.prompts_mapping: t.Dict[str, types.Prompt] = {} + self.resources_mapping: t.Dict[str, types.Resource] = {} + self.resources_templates_mapping: t.Dict[str, types.ResourceTemplate] = {} self.aggregated_server = self._create_aggregated_server() - self.profile_manager = ProfileConfigManager() + self.profile_manager = ProfileConfigManager() if profile_path is None else ProfileConfigManager(profile_path) self.watcher: Optional[ConfigWatcher] = None if reload_server: self.watcher = ConfigWatcher(self.profile_manager.profile_path) + self.strict: bool = strict def get_unique_servers(self) -> list[ServerConfig]: profiles = self.profile_manager.list_profiles() @@ -120,35 +129,77 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None: # Collect server tools, prompts, and resources if response.capabilities.tools: tools = await client.session.list_tools() # type: ignore - # Add tools with namespaced names, preserving existing tools - self.tools_mapping.update( - {f"{server_id}{TOOL_SPLITOR}{tool.name}": tool.model_dump() for tool in tools.tools} - ) + for tool in tools.tools: + # To make sure tool name is unique across all servers + tool_name = tool.name + if tool_name in self.capabilities_to_server_id["tools"]: + if self.strict: + raise ValueError( + f"Tool {tool_name} already exists. Please use unique tool names across all servers." + ) + else: + # Auto resolve by adding server name prefix + tool_name = f"{server_id}{TOOL_SPLITOR}{tool_name}" + self.capabilities_to_server_id["tools"][tool_name] = server_id + self.tools_mapping[tool_name] = tool if response.capabilities.prompts: prompts = await client.session.list_prompts() # type: ignore - # Add prompts with namespaced names, preserving existing prompts - self.prompts_mapping.update( - {f"{server_id}{PROMPT_SPLITOR}{prompt.name}": prompt.model_dump() for prompt in prompts.prompts} - ) + for prompt in prompts.prompts: + # To make sure prompt name is unique across all servers + prompt_name = prompt.name + if prompt_name in self.capabilities_to_server_id["prompts"]: + if self.strict: + raise ValueError( + f"Prompt {prompt_name} already exists. Please use unique prompt names across all servers." 
+ ) + else: + # Auto resolve by adding server name prefix + prompt_name = f"{server_id}{PROMPT_SPLITOR}{prompt_name}" + self.prompts_mapping[prompt_name] = prompt + self.capabilities_to_server_id["prompts"][prompt_name] = server_id if response.capabilities.resources: resources = await client.session.list_resources() # type: ignore - # Add resources with namespaced URIs, preserving existing resources - self.resources_mapping.update( - { - f"{server_id}{RESOURCE_SPLITOR}{resource.uri}": resource.model_dump() - for resource in resources.resources - } - ) + for resource in resources.resources: + # To make sure resource URI is unique across all servers + resource_uri = resource.uri + if str(resource_uri) in self.capabilities_to_server_id["resources"]: + if self.strict: + raise ValueError( + f"Resource {resource_uri} already exists. Please use unique resource URIs across all servers." + ) + else: + # Auto resolve by adding server name prefix + host = resource_uri.host + resource_uri = AnyUrl.build( + host=f"{server_id}{RESOURCE_SPLITOR}{host}", + scheme=resource_uri.scheme, + path=resource_uri.path, + username=resource_uri.username, + password=resource_uri.password, + port=resource_uri.port, + query=resource_uri.query, + fragment=resource_uri.fragment, + ) + self.resources_mapping[str(resource_uri)] = resource + self.capabilities_to_server_id["resources"][str(resource_uri)] = server_id resources_templates = await client.session.list_resource_templates() # type: ignore - # Add resource templates with namespaced URIs, preserving existing templates - self.resources_templates_mapping.update( - { - f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}": resource_template.model_dump() - for resource_template in resources_templates.resourceTemplates - } - ) + for resource_template in resources_templates.resourceTemplates: + # To make sure resource template URI is unique across all servers + resource_template_uri_template = resource_template.uriTemplate + if resource_template_uri_template in self.capabilities_to_server_id["resource_templates"]: + if self.strict: + raise ValueError( + f"Resource template {resource_template_uri_template} already exists. Please use unique resource template URIs across all servers." 
+ ) + else: + # Auto resolve by adding server name prefix + resource_template_uri_template = ( + f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}{resource_template.uriTemplate}" + ) + self.resources_templates_mapping[resource_template_uri_template] = resource_template + self.capabilities_to_server_id["resource_templates"][resource_template_uri_template] = server_id async def remove_server(self, server_id: str) -> None: """ @@ -170,28 +221,32 @@ async def remove_server(self, server_id: str) -> None: # Delete registered tools, resources and prompts for key in list(self.tools_mapping.keys()): - if key.startswith(f"{server_id}{TOOL_SPLITOR}"): + if self.capabilities_to_server_id["tools"].get(key) == server_id: self.tools_mapping.pop(key) + self.capabilities_to_server_id["tools"].pop(key) for key in list(self.prompts_mapping.keys()): - if key.startswith(f"{server_id}{PROMPT_SPLITOR}"): + if self.capabilities_to_server_id["prompts"].get(key) == server_id: self.prompts_mapping.pop(key) + self.capabilities_to_server_id["prompts"].pop(key) for key in list(self.resources_mapping.keys()): - if key.startswith(f"{server_id}{RESOURCE_SPLITOR}"): + if self.capabilities_to_server_id["resources"].get(key) == server_id: self.resources_mapping.pop(key) + self.capabilities_to_server_id["resources"].pop(key) for key in list(self.resources_templates_mapping.keys()): - if key.startswith(f"{server_id}{RESOURCE_TEMPLATE_SPLITOR}"): + if self.capabilities_to_server_id["resource_templates"].get(key) == server_id: self.resources_templates_mapping.pop(key) + self.capabilities_to_server_id["resource_templates"].pop(key) def _patch_handler_func(self, app: server.Server) -> server.Server: def get_active_servers(profile: str) -> list[str]: servers = self.profile_manager.get_profile(profile) or [] return [server.name for server in servers] - def parse_namespaced_id(id_value, splitor): - """Parse namespaced ID, return server ID and original ID.""" - if splitor in str(id_value): - return str(id_value).split(splitor, 1) - return None, None + def get_capability_server_id( + capability_type: Literal["tools", "prompts", "resources", "resource_templates"], id_value: str + ) -> str | None: + """Get the server ID associated with a capability ID.""" + return self.capabilities_to_server_id[capability_type].get(id_value) def empty_result() -> types.ServerResult: return types.ServerResult(types.EmptyResult()) @@ -200,67 +255,76 @@ async def list_prompts(req: types.ListPromptsRequest) -> types.ServerResult: prompts: list[types.Prompt] = [] active_servers = get_active_servers(req.params.meta.profile) # type: ignore for server_prompt_id, prompt in self.prompts_mapping.items(): - server_id, _ = parse_namespaced_id(server_prompt_id, PROMPT_SPLITOR) + server_id = get_capability_server_id("prompts", server_prompt_id) if server_id in active_servers: - prompt.update({"name": server_prompt_id}) - prompts.append(types.Prompt(**prompt)) + prompts.append(prompt.model_copy(update={"name": server_prompt_id})) return types.ServerResult(types.ListPromptsResult(prompts=prompts)) @trace_event(AccessEventType.PROMPT_EXECUTION) async def get_prompt(req: types.GetPromptRequest) -> types.ServerResult: active_servers = get_active_servers(req.params.meta.profile) # type: ignore - server_id, prompt_name = parse_namespaced_id(req.params.name, PROMPT_SPLITOR) - if server_id is None or prompt_name is None: + server_id = get_capability_server_id("prompts", req.params.name) + if server_id is None: return empty_result() if server_id not in active_servers: return 
empty_result() - - result = await self.server_sessions[server_id].session.get_prompt(prompt_name, req.params.arguments) + prompt = self.prompts_mapping.get(req.params.name) + if prompt is None: + return empty_result() + result = await self.server_sessions[server_id].session.get_prompt(prompt.name, req.params.arguments) return types.ServerResult(result) async def list_resources(req: types.ListResourcesRequest) -> types.ServerResult: resources: list[types.Resource] = [] active_servers = get_active_servers(req.params.meta.profile) # type: ignore for server_resource_id, resource in self.resources_mapping.items(): - server_id, _ = parse_namespaced_id(server_resource_id, RESOURCE_SPLITOR) + server_id = get_capability_server_id("resources", server_resource_id) + if server_id is None: + continue if server_id in active_servers: - resource.update({"uri": server_resource_id}) - resources.append(types.Resource(**resource)) + resources.append(resource.model_copy(update={"uri": AnyUrl(server_resource_id)})) return types.ServerResult(types.ListResourcesResult(resources=resources)) async def list_resource_templates(req: types.ListResourceTemplatesRequest) -> types.ServerResult: resource_templates: list[types.ResourceTemplate] = [] active_servers = get_active_servers(req.params.meta.profile) # type: ignore for server_resource_template_id, resource_template in self.resources_templates_mapping.items(): - server_id, _ = parse_namespaced_id(server_resource_template_id, RESOURCE_TEMPLATE_SPLITOR) + server_id = get_capability_server_id("resource_templates", server_resource_template_id) + if server_id is None: + continue if server_id in active_servers: - resource_template.update({"uriTemplate": server_resource_template_id}) - resource_templates.append(types.ResourceTemplate(**resource_template)) + resource_templates.append( + resource_template.model_copy(update={"uriTemplate": server_resource_template_id}) + ) return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=resource_templates)) @trace_event(AccessEventType.RESOURCE_ACCESS) async def read_resource(req: types.ReadResourceRequest) -> types.ServerResult: active_servers = get_active_servers(req.params.meta.profile) # type: ignore - server_id, resource_uri = parse_namespaced_id(req.params.uri, RESOURCE_SPLITOR) - if server_id is None or resource_uri is None: + server_id = get_capability_server_id("resources", str(req.params.uri)) + if server_id is None: return empty_result() if server_id not in active_servers: return empty_result() + resource = self.resources_mapping.get(str(req.params.uri)) + if resource is None: + return empty_result() - result = await self.server_sessions[server_id].session.read_resource(AnyUrl(resource_uri)) + result = await self.server_sessions[server_id].session.read_resource(resource.uri) return types.ServerResult(result) async def list_tools(req: types.ListToolsRequest) -> types.ServerResult: tools: list[types.Tool] = [] active_servers = get_active_servers(req.params.meta.profile) # type: ignore for server_tool_id, tool in self.tools_mapping.items(): - server_id, _ = parse_namespaced_id(server_tool_id, TOOL_SPLITOR) + server_id = get_capability_server_id("tools", server_tool_id) + if server_id is None: + continue if server_id in active_servers: - tool.update({"name": server_tool_id}) - tools.append(types.Tool(**tool)) + tools.append(tool.model_copy(update={"name": server_tool_id})) if not tools: return empty_result() @@ -272,16 +336,25 @@ async def call_tool(req: types.CallToolRequest) -> types.ServerResult: 
active_servers = get_active_servers(req.params.meta.profile) # type: ignore logger.info(f"call_tool: {req} with active servers: {active_servers}") - server_id, tool_name = parse_namespaced_id(req.params.name, TOOL_SPLITOR) - if server_id is None or tool_name is None: + tool_name = req.params.name + server_id = get_capability_server_id("tools", tool_name) + if server_id is None: + logger.debug(f"call_tool: {req} with tool_name: {tool_name}. Server ID {server_id} is not found") return empty_result() if server_id not in active_servers: + logger.debug( + f"call_tool: {req} with tool_name: {tool_name}. Server ID {server_id} is not in active servers" + ) + return empty_result() + tool = self.tools_mapping.get(tool_name) + if tool is None: return empty_result() try: - result = await self.server_sessions[server_id].session.call_tool(tool_name, req.params.arguments or {}) + result = await self.server_sessions[server_id].session.call_tool(tool.name, req.params.arguments or {}) return types.ServerResult(result) except Exception as e: + logger.error(f"Error calling tool {tool_name} on server {server_id}: {e}") return types.ServerResult( types.CallToolResult( content=[types.TextContent(type="text", text=str(e))], @@ -293,20 +366,28 @@ async def complete(req: types.CompleteRequest) -> types.ServerResult: active_servers = get_active_servers(req.params.meta.profile) # type: ignore if isinstance(req.params.ref, types.PromptReference): - server_id, prompt_name = parse_namespaced_id(req.params.ref.name, PROMPT_SPLITOR) - if server_id is None or prompt_name is None: + server_id = get_capability_server_id("prompts", req.params.ref.name) + if server_id is None: + return empty_result() + if server_id not in active_servers: return empty_result() - ref = types.PromptReference(name=prompt_name, type="ref/prompt") + prompt = self.prompts_mapping.get(req.params.ref.name) + if prompt is None: + return empty_result() + ref = types.PromptReference(name=prompt.name, type="ref/prompt") elif isinstance(req.params.ref, types.ResourceReference): - server_id, resource_uri = parse_namespaced_id(req.params.ref.uri, RESOURCE_SPLITOR) - if server_id is None or resource_uri is None: + server_id = get_capability_server_id("resources", str(req.params.ref.uri)) + if server_id is None: + return empty_result() + resource = self.resources_mapping.get(str(req.params.ref.uri)) + if resource is None: return empty_result() - ref = types.ResourceReference(uri=resource_uri, type="ref/resource") + ref = types.ResourceReference(uri=str(resource.uri), type="ref/resource") if server_id not in active_servers: return empty_result() - result = await self.server_sessions[server_id].session.complete(ref, req.params.argument.model_dump()) + result = await self.server_sessions[server_id].session.complete(ref, req.params.arguments or {}) return types.ServerResult(result) app.request_handlers[types.ListPromptsRequest] = list_prompts @@ -400,24 +481,23 @@ async def _initialize_server_capabilities(self): capabilities=capabilities, ) - async def start_sse_server( - self, host: str = "localhost", port: int = 8080, allow_origins: t.Optional[t.List[str]] = None - ) -> None: + async def get_sse_server_app( + self, allow_origins: t.Optional[t.List[str]] = None, include_lifespan: bool = True + ) -> AppType: """ - Start an SSE server that exposes the aggregated MCP server. + Get the SSE server app. 
Args: - host: The host to bind to - port: The port to bind to allow_origins: List of allowed origins for CORS + include_lifespan: Whether to include the router's lifespan manager in the app. + + Returns: + An SSE server app """ - # waiting all servers to be initialized await self.initialize_router() - # Create SSE transport sse = RouterSseTransport("/messages/") - # Handle SSE connections async def handle_sse(request: Request) -> None: async with sse.connect_sse( request.scope, @@ -430,12 +510,16 @@ async def handle_sse(request: Request) -> None: self.aggregated_server.initialization_options, ) - @asynccontextmanager - async def lifespan(app: AppType): - yield - await self.shutdown() + lifespan_handler: t.Optional[Lifespan[AppType]] = None + if include_lifespan: + + @asynccontextmanager + async def lifespan(app: AppType): + yield + await self.shutdown() + + lifespan_handler = lifespan - # Set up middleware for CORS if needed middleware: t.List[Middleware] = [] if allow_origins is not None: middleware.append( @@ -447,7 +531,6 @@ async def lifespan(app: AppType): ), ) - # Create Starlette app app = Starlette( debug=False, middleware=middleware, @@ -455,8 +538,22 @@ async def lifespan(app: AppType): Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], - lifespan=lifespan, + lifespan=lifespan_handler, ) + return app + + async def start_sse_server( + self, host: str = "localhost", port: int = 8080, allow_origins: t.Optional[t.List[str]] = None + ) -> None: + """ + Start an SSE server that exposes the aggregated MCP server. + + Args: + host: The host to bind to + port: The port to bind to + allow_origins: List of allowed origins for CORS + """ + app = await self.get_sse_server_app(allow_origins) # Configure and start the server config = uvicorn.Config( diff --git a/src/mcpm/router/transport.py b/src/mcpm/router/transport.py index 7b151c82..d771e2f0 100644 --- a/src/mcpm/router/transport.py +++ b/src/mcpm/router/transport.py @@ -99,11 +99,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # maintain session_id to identifier mapping profile = get_key_from_scope(scope, key_name="profile") client_id = get_key_from_scope(scope, key_name="client") - if profile is not None: - self._session_id_to_identifier[session_id] = ClientIdentifier( - client_id=client_id or "anonymous", profile=profile, api_key=api_key - ) - logger.debug(f"Session {session_id} mapped to identifier {self._session_id_to_identifier[session_id]}") + logger.debug(f"Profile: {profile}, Client ID: {client_id}") + client_id = client_id or "anonymous" + profile = profile or "default" + self._session_id_to_identifier[session_id] = ClientIdentifier( + client_id=client_id, profile=profile, api_key=api_key + ) + logger.debug(f"Session {session_id} mapped to identifier {self._session_id_to_identifier[session_id]}") sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) diff --git a/tests/test_add.py b/tests/test_add.py index 8b6d6177..9c30c170 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -9,9 +9,9 @@ def test_add_server(windsurf_manager, monkeypatch): """Test add server""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", 
Mock(return_value="@windsurf")) monkeypatch.setattr( RepositoryManager, "_fetch_servers", @@ -51,9 +51,9 @@ def test_add_server(windsurf_manager, monkeypatch): def test_add_server_with_missing_arg(windsurf_manager, monkeypatch): """Test add server with a missing argument that should be replaced with empty string""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr( RepositoryManager, "_fetch_servers", @@ -118,6 +118,7 @@ def test_add_server_with_empty_args(windsurf_manager, monkeypatch): monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr( RepositoryManager, "_fetch_servers", diff --git a/tests/test_remove.py b/tests/test_remove.py index 27716a02..151e2816 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -9,10 +9,10 @@ def test_remove_server_success(windsurf_manager, monkeypatch): """Test successful server removal""" # Setup mocks - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) # Mock server info mock_server = Mock() @@ -33,10 +33,10 @@ def test_remove_server_success(windsurf_manager, monkeypatch): def test_remove_server_not_found(windsurf_manager, monkeypatch): """Test removal of non-existent server""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) # Mock server not found windsurf_manager.get_server = Mock(return_value=None) @@ -51,7 +51,7 @@ def test_remove_server_not_found(windsurf_manager, monkeypatch): def test_remove_server_unsupported_client(monkeypatch): """Test removal with unsupported client""" monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=None)) - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="unsupported")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@unsupported")) runner = CliRunner() result = runner.invoke(remove, ["server-test"]) @@ -62,10 +62,10 @@ def test_remove_server_unsupported_client(monkeypatch): def test_remove_server_cancelled(windsurf_manager, monkeypatch): """Test removal when user cancels the confirmation""" 
- monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) # Mock server info mock_server = Mock() @@ -87,10 +87,10 @@ def test_remove_server_cancelled(windsurf_manager, monkeypatch): def test_remove_server_failure(windsurf_manager, monkeypatch): """Test removal when the removal operation fails""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) # Mock server info mock_server = Mock() diff --git a/tests/test_stash_pop.py b/tests/test_stash_pop.py index d42e9d07..b396b460 100644 --- a/tests/test_stash_pop.py +++ b/tests/test_stash_pop.py @@ -10,7 +10,7 @@ def test_stash_server_success(windsurf_manager, monkeypatch): """Test successful server stashing""" # Setup mocks - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) @@ -49,7 +49,7 @@ def test_stash_server_success(windsurf_manager, monkeypatch): def test_stash_server_already_stashed(windsurf_manager, monkeypatch): """Test stashing an already stashed server""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) @@ -74,7 +74,7 @@ def test_stash_server_already_stashed(windsurf_manager, monkeypatch): def test_stash_server_remove_failure(windsurf_manager, monkeypatch): """Test stashing when server removal fails""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) @@ -103,7 +103,7 @@ def test_stash_server_remove_failure(windsurf_manager, monkeypatch): def test_stash_server_not_found(windsurf_manager, monkeypatch): """Test stashing a non-existent 
server""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) @@ -126,8 +126,8 @@ def test_stash_server_not_found(windsurf_manager, monkeypatch): def test_stash_server_unsupported_client(monkeypatch): """Test stashing with unsupported client""" + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@unsupported")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=None)) - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="unsupported")) # Mock client config manager mock_config_manager = Mock() @@ -143,7 +143,7 @@ def test_stash_server_unsupported_client(monkeypatch): def test_pop_server_success(windsurf_manager, monkeypatch): """Test successful server restoration""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) @@ -178,7 +178,7 @@ def test_pop_server_success(windsurf_manager, monkeypatch): def test_pop_server_not_stashed(windsurf_manager, monkeypatch): """Test popping a non-stashed server""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) @@ -198,7 +198,7 @@ def test_pop_server_not_stashed(windsurf_manager, monkeypatch): def test_pop_server_add_failure(windsurf_manager, monkeypatch): """Test popping when server addition fails""" - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="windsurf")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@windsurf")) monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=windsurf_manager)) monkeypatch.setattr(ClientRegistry, "get_client_info", Mock(return_value={"name": "windsurf"})) monkeypatch.setattr(ClientRegistry, "get_client_manager", Mock(return_value=windsurf_manager)) @@ -233,7 +233,7 @@ def test_pop_server_add_failure(windsurf_manager, monkeypatch): def test_pop_server_unsupported_client(monkeypatch): """Test popping with unsupported client""" monkeypatch.setattr(ClientRegistry, "get_active_client_manager", Mock(return_value=None)) - monkeypatch.setattr(ClientRegistry, "get_active_client", Mock(return_value="unsupported")) + monkeypatch.setattr(ClientRegistry, "determine_active_scope", Mock(return_value="@unsupported")) # Mock client config manager mock_config_manager = Mock()