From 02bb864d567142ecbdcadb2846c817b4b5f834be Mon Sep 17 00:00:00 2001 From: Roo Code Date: Mon, 21 Jul 2025 15:01:22 +0000 Subject: [PATCH] feat: respect mcpEnabled setting to prevent automatic MCP server initialization - Add conditional initialization in ClineProvider based on mcpEnabled setting - Update McpHub constructor to accept autoInitialize parameter - Add enableServers() and disableServers() methods for dynamic control - Handle mcpEnabled toggle in webview message handler - Update tests to cover new behavior Fixes #6014 --- src/core/webview/ClineProvider.ts | 33 +++-- src/core/webview/webviewMessageHandler.ts | 11 ++ src/services/mcp/McpHub.ts | 43 +++++- src/services/mcp/McpServerManager.ts | 8 +- src/services/mcp/__tests__/McpHub.spec.ts | 153 ++++++++++++++++++++++ 5 files changed, 234 insertions(+), 14 deletions(-) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 6231f081670..666d3e92ae0 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -148,19 +148,34 @@ export class ClineProvider await this.postStateToWebview() }) - // Initialize MCP Hub through the singleton manager - McpServerManager.getInstance(this.context, this) - .then((hub) => { - this.mcpHub = hub - this.mcpHub.registerClient() - }) - .catch((error) => { - this.log(`Failed to initialize MCP Hub: ${error}`) - }) + // Initialize MCP Hub through the singleton manager only if mcpEnabled + this.initializeMcpIfEnabled() this.marketplaceManager = new MarketplaceManager(this.context) } + private async initializeMcpIfEnabled() { + try { + // Get the mcpEnabled setting from global state + const mcpEnabled = this.contextProxy.getValue("mcpEnabled") ?? true + + if (mcpEnabled) { + // Initialize MCP Hub through the singleton manager + const hub = await McpServerManager.getInstance(this.context, this, true) + this.mcpHub = hub + this.mcpHub.registerClient() + } else { + this.log("MCP is disabled, skipping MCP Hub initialization") + // Still create the hub instance but don't initialize servers + const hub = await McpServerManager.getInstance(this.context, this, false) + this.mcpHub = hub + this.mcpHub.registerClient() + } + } catch (error) { + this.log(`Failed to initialize MCP Hub: ${error}`) + } + } + // Adds a new Cline instance to clineStack, marking the start of a new task. // The instance is pushed to the top of the stack (LIFO order). // When the task is completed, the top instance is removed, reactivating the previous task. diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 780d40df891..8483d253b3b 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -880,6 +880,17 @@ export const webviewMessageHandler = async ( case "mcpEnabled": const mcpEnabled = message.bool ?? true await updateGlobalState("mcpEnabled", mcpEnabled) + + // Enable or disable MCP servers based on the setting + const mcpHubInstance = provider.getMcpHub() + if (mcpHubInstance) { + if (mcpEnabled) { + await mcpHubInstance.enableServers() + } else { + await mcpHubInstance.disableServers() + } + } + await provider.postStateToWebview() break case "enableMcpServerCreation": diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 10a74712ef0..7bd40945da8 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -135,13 +135,16 @@ export class McpHub { private refCount: number = 0 // Reference counter for active clients private configChangeDebounceTimers: Map = new Map() - constructor(provider: ClineProvider) { + constructor(provider: ClineProvider, autoInitialize: boolean = true) { this.providerRef = new WeakRef(provider) this.watchMcpSettingsFile() this.watchProjectMcpFile().catch(console.error) this.setupWorkspaceFoldersWatcher() - this.initializeGlobalMcpServers() - this.initializeProjectMcpServers() + + if (autoInitialize) { + this.initializeGlobalMcpServers() + this.initializeProjectMcpServers() + } } /** * Registers a client (e.g., ClineProvider) using this hub. @@ -1643,4 +1646,38 @@ export class McpHub { } this.disposables.forEach((d) => d.dispose()) } + + /** + * Enable MCP servers by initializing them + */ + async enableServers(): Promise { + if (this.connections.length > 0) { + console.log("MCP servers are already initialized") + return + } + + console.log("Enabling MCP servers...") + await this.initializeGlobalMcpServers() + await this.initializeProjectMcpServers() + await this.notifyWebviewOfServerChanges() + } + + /** + * Disable MCP servers by disconnecting all connections + */ + async disableServers(): Promise { + console.log("Disabling MCP servers...") + + // Disconnect all servers + const allConnections = [...this.connections] + for (const conn of allConnections) { + await this.deleteConnection(conn.server.name, conn.server.source) + } + + // Clear connections array + this.connections = [] + + // Notify webview of changes + await this.notifyWebviewOfServerChanges() + } } diff --git a/src/services/mcp/McpServerManager.ts b/src/services/mcp/McpServerManager.ts index e15f9db0a7a..8bb72baa101 100644 --- a/src/services/mcp/McpServerManager.ts +++ b/src/services/mcp/McpServerManager.ts @@ -17,7 +17,11 @@ export class McpServerManager { * Creates a new instance if one doesn't exist. * Thread-safe implementation using a promise-based lock. */ - static async getInstance(context: vscode.ExtensionContext, provider: ClineProvider): Promise { + static async getInstance( + context: vscode.ExtensionContext, + provider: ClineProvider, + autoInitialize: boolean = true, + ): Promise { // Register the provider this.providers.add(provider) @@ -36,7 +40,7 @@ export class McpServerManager { try { // Double-check instance in case it was created while we were waiting if (!this.instance) { - this.instance = new McpHub(provider) + this.instance = new McpHub(provider, autoInitialize) // Store a unique identifier in global state to track the primary instance await context.globalState.update(this.GLOBAL_STATE_KEY, Date.now().toString()) } diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 7dc7f00c045..cc05e969f25 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -65,6 +65,19 @@ vi.mock("vscode", () => ({ Disposable: { from: vi.fn(), }, + Uri: { + file: vi.fn((path) => ({ + scheme: "file", + authority: "", + path, + query: "", + fragment: "", + fsPath: path, + with: vi.fn(), + toJSON: vi.fn(), + })), + }, + RelativePattern: vi.fn((base, pattern) => ({ base, pattern })), })) vi.mock("fs/promises") vi.mock("../../../core/webview/ClineProvider") @@ -1226,4 +1239,144 @@ describe("McpHub", () => { ) }) }) + + describe("Conditional initialization", () => { + it("should not initialize servers when autoInitialize is false", async () => { + // Clear existing mocks + vi.clearAllMocks() + + // Create McpHub with autoInitialize = false + const mcpHubNoInit = new McpHub(mockProvider as ClineProvider, false) + + // Verify no connections were created + expect(mcpHubNoInit.connections.length).toBe(0) + + // Verify fs.readFile was not called for server initialization + expect(fs.readFile).not.toHaveBeenCalled() + }) + + it("should initialize servers when autoInitialize is true", async () => { + // Clear existing mocks + vi.clearAllMocks() + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub with autoInitialize = true + const mcpHubInit = new McpHub(mockProvider as ClineProvider, true) + + // Wait a bit for async initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify fs.readFile was called for initialization + expect(fs.readFile).toHaveBeenCalled() + }) + + it("should enable servers after creation when enableServers is called", async () => { + // Clear existing mocks + vi.clearAllMocks() + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Mock StdioClientTransport + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Create McpHub without auto-initialization + const mcpHubNoInit = new McpHub(mockProvider as ClineProvider, false) + + // Verify no connections initially + expect(mcpHubNoInit.connections.length).toBe(0) + + // Enable servers + await mcpHubNoInit.enableServers() + + // Verify connections were created + expect(mcpHubNoInit.connections.length).toBeGreaterThan(0) + }) + + it("should disable all servers when disableServers is called", async () => { + // Create McpHub with some connections + mcpHub.connections = [ + { + server: { + name: "test-server-1", + source: "global", + } as any, + client: { + close: vi.fn().mockResolvedValue(undefined), + } as any, + transport: { + close: vi.fn().mockResolvedValue(undefined), + } as any, + }, + { + server: { + name: "test-server-2", + source: "project", + } as any, + client: { + close: vi.fn().mockResolvedValue(undefined), + } as any, + transport: { + close: vi.fn().mockResolvedValue(undefined), + } as any, + }, + ] + + // Disable servers + await mcpHub.disableServers() + + // Verify all connections were closed + expect(mcpHub.connections.length).toBe(0) + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith( + expect.objectContaining({ + type: "mcpServers", + mcpServers: [], + }), + ) + }) + }) })