From 6155283ff03b9345bb399a5c0e82eedd16fa8f32 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Fri, 22 Aug 2025 00:22:10 +0000 Subject: [PATCH 1/2] fix: prevent MCP server restarts during active tool executions - Add tracking of active tool executions in McpHub - Prevent server restarts when tools are running - Update toggleToolAlwaysAllow to skip restart during tool execution - Update toggleToolEnabledForPrompt to skip restart during tool execution - Prevent toggleServerDisabled when tools are running - Add comprehensive tests for the new behavior Fixes #7189 --- src/services/mcp/McpHub.ts | 234 ++++++++++++++++++++-- src/services/mcp/__tests__/McpHub.spec.ts | 156 +++++++++++++++ 2 files changed, 376 insertions(+), 14 deletions(-) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 271c6e1fb3..e9bbf75614 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -151,6 +151,7 @@ export class McpHub { isConnecting: boolean = false private refCount: number = 0 // Reference counter for active clients private configChangeDebounceTimers: Map = new Map() + private activeToolExecutions: Map> = new Map() // Track active tool executions per server constructor(provider: ClineProvider) { this.providerRef = new WeakRef(provider) @@ -1169,6 +1170,16 @@ export class McpHub { } async restartConnection(serverName: string, source?: "global" | "project"): Promise { + // Check if there are active tool executions for this server + if (this.hasActiveToolExecutions(serverName, source)) { + console.log(`Skipping restart for ${serverName} - tools are currently executing`) + vscode.window.showWarningMessage( + t("mcp:errors.cannot_restart_tools_running", { serverName }) || + `Cannot restart server "${serverName}" while tools are running. Please wait for tool execution to complete.`, + ) + return + } + this.isConnecting = true // Check if MCP is globally enabled @@ -1357,6 +1368,20 @@ export class McpHub { } const serverSource = connection.server.source || "global" + + // Check if there are active tool executions for this server + if (this.hasActiveToolExecutions(serverName, serverSource)) { + // Show a warning message and don't proceed with the toggle + vscode.window.showWarningMessage( + t("mcp:errors.cannot_toggle_server_tools_running", { + serverName, + action: disabled ? "disable" : "enable", + }) || + `Cannot ${disabled ? "disable" : "enable"} server "${serverName}" while tools are running. Please wait for tool execution to complete.`, + ) + return + } + // Update the server config in the appropriate file await this.updateServerConfig(serverName, { disabled }, serverSource) @@ -1603,19 +1628,69 @@ export class McpHub { timeout = 60 * 1000 } - return await connection.client.request( - { - method: "tools/call", - params: { - name: toolName, - arguments: toolArguments, + // Track this tool execution as active + const serverKey = `${serverName}:${source || "global"}` + if (!this.activeToolExecutions.has(serverKey)) { + this.activeToolExecutions.set(serverKey, new Set()) + } + const executionId = `${toolName}:${Date.now()}` + this.activeToolExecutions.get(serverKey)!.add(executionId) + + try { + const result = await connection.client.request( + { + method: "tools/call", + params: { + name: toolName, + arguments: toolArguments, + }, }, - }, - CallToolResultSchema, - { - timeout, - }, - ) + CallToolResultSchema, + { + timeout, + }, + ) + + // Remove from active executions on success + this.activeToolExecutions.get(serverKey)?.delete(executionId) + if (this.activeToolExecutions.get(serverKey)?.size === 0) { + this.activeToolExecutions.delete(serverKey) + } + + return result + } catch (error) { + // Remove from active executions on error + this.activeToolExecutions.get(serverKey)?.delete(executionId) + if (this.activeToolExecutions.get(serverKey)?.size === 0) { + this.activeToolExecutions.delete(serverKey) + } + throw error + } + } + + /** + * Check if any tools are currently executing for a specific server + * @param serverName The name of the server to check + * @param source The source of the server (global or project) + * @returns true if there are active tool executions, false otherwise + */ + private hasActiveToolExecutions(serverName: string, source?: "global" | "project"): boolean { + const serverKey = `${serverName}:${source || "global"}` + const activeTools = this.activeToolExecutions.get(serverKey) + return activeTools ? activeTools.size > 0 : false + } + + /** + * Check if any tools are currently executing across all servers + * @returns true if there are any active tool executions, false otherwise + */ + private hasAnyActiveToolExecutions(): boolean { + for (const [, tools] of this.activeToolExecutions) { + if (tools.size > 0) { + return true + } + } + return false } /** @@ -1703,7 +1778,15 @@ export class McpHub { shouldAllow: boolean, ): Promise { try { - await this.updateServerToolList(serverName, source, toolName, "alwaysAllow", shouldAllow) + // Check if there are active tool executions for this server + if (this.hasActiveToolExecutions(serverName, source)) { + console.log(`Skipping server restart for ${serverName} - tools are currently executing`) + // Update the config file without triggering a restart + await this.updateServerToolListWithoutRestart(serverName, source, toolName, "alwaysAllow", shouldAllow) + } else { + // Normal flow - update and allow restart if needed + await this.updateServerToolList(serverName, source, toolName, "alwaysAllow", shouldAllow) + } } catch (error) { this.showErrorMessage( `Failed to toggle always allow for tool "${toolName}" on server "${serverName}" with source "${source}"`, @@ -1713,6 +1796,114 @@ export class McpHub { } } + /** + * Update server tool list without triggering a restart + * This is used when tools are actively running to prevent interruption + */ + private async updateServerToolListWithoutRestart( + serverName: string, + source: "global" | "project", + toolName: string, + listName: "alwaysAllow" | "disabledTools", + addTool: boolean, + ): Promise { + // Find the connection with matching name and source + const connection = this.findConnection(serverName, source) + + if (!connection) { + throw new Error(`Server ${serverName} with source ${source} not found`) + } + + // Determine the correct config path based on the source + let configPath: string + if (source === "project") { + // Get project MCP config path + const projectMcpPath = await this.getProjectMcpPath() + if (!projectMcpPath) { + throw new Error("Project MCP configuration file not found") + } + configPath = projectMcpPath + } else { + // Get global MCP settings path + configPath = await this.getMcpSettingsFilePath() + } + + // Normalize path for cross-platform compatibility + const normalizedPath = process.platform === "win32" ? configPath.replace(/\\/g, "/") : configPath + + // Read the appropriate config file + const content = await fs.readFile(normalizedPath, "utf-8") + const config = JSON.parse(content) + + if (!config.mcpServers) { + config.mcpServers = {} + } + + if (!config.mcpServers[serverName]) { + config.mcpServers[serverName] = { + type: "stdio", + command: "node", + args: [], // Default to an empty array; can be set later if needed + } + } + + if (!config.mcpServers[serverName][listName]) { + config.mcpServers[serverName][listName] = [] + } + + const targetList = config.mcpServers[serverName][listName] + const toolIndex = targetList.indexOf(toolName) + + if (addTool && toolIndex === -1) { + targetList.push(toolName) + } else if (!addTool && toolIndex !== -1) { + targetList.splice(toolIndex, 1) + } + + // Write the config file directly without triggering file watcher + // We'll temporarily disable the watcher to prevent restart + const watcherKey = source === "project" ? this.projectMcpWatcher : this.settingsWatcher + const wasWatcherActive = !!watcherKey + + if (wasWatcherActive && source === "project" && this.projectMcpWatcher) { + this.projectMcpWatcher.dispose() + this.projectMcpWatcher = undefined + } else if (wasWatcherActive && source === "global" && this.settingsWatcher) { + this.settingsWatcher.dispose() + this.settingsWatcher = undefined + } + + await fs.writeFile(normalizedPath, JSON.stringify(config, null, 2)) + + // Update the in-memory tool state without restarting + if (connection) { + // Update the tool's alwaysAllow or enabledForPrompt status in memory + const tools = connection.server.tools + if (tools) { + const tool = tools.find((t) => t.name === toolName) + if (tool) { + if (listName === "alwaysAllow") { + tool.alwaysAllow = addTool + } else if (listName === "disabledTools") { + tool.enabledForPrompt = !addTool + } + } + } + await this.notifyWebviewOfServerChanges() + } + + // Re-enable the watcher after a short delay + if (wasWatcherActive) { + setTimeout(async () => { + if (source === "project") { + await this.watchProjectMcpFile() + } else { + await this.watchMcpSettingsFile() + } + }, 1000) + } + } + async toggleToolEnabledForPrompt( serverName: string, source: "global" | "project", @@ -1723,7 +1914,22 @@ export class McpHub { // When isEnabled is true, we want to remove the tool from the disabledTools list. // When isEnabled is false, we want to add the tool to the disabledTools list. const addToolToDisabledList = !isEnabled - await this.updateServerToolList(serverName, source, toolName, "disabledTools", addToolToDisabledList) + + // Check if there are active tool executions for this server + if (this.hasActiveToolExecutions(serverName, source)) { + console.log(`Skipping server restart for ${serverName} - tools are currently executing`) + // Update the config file without triggering a restart + await this.updateServerToolListWithoutRestart( + serverName, + source, + toolName, + "disabledTools", + addToolToDisabledList, + ) + } else { + // Normal flow - update and allow restart if needed + await this.updateServerToolList(serverName, source, toolName, "disabledTools", addToolToDisabledList) + } } catch (error) { this.showErrorMessage(`Failed to update settings for tool ${toolName}`, error) throw error // Re-throw to ensure the error is properly handled diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index ebce2d5b2a..9165851199 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -908,6 +908,162 @@ describe("McpHub", () => { expect(writtenConfig.mcpServers["test-server"].alwaysAllow).toBeDefined() expect(writtenConfig.mcpServers["test-server"].alwaysAllow).toContain("new-tool") }) + + it("should skip server restart when toggling alwaysAllow during active tool execution", async () => { + // Mock fs.readFile to return existing config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + alwaysAllow: [], + }, + }, + }), + ) + + // Create a connected server + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ + type: "stdio", + command: "node", + args: ["test.js"], + alwaysAllow: [], + }), + status: "connected", + source: "global", + tools: [ + { + name: "test-tool", + description: "Test tool", + alwaysAllow: false, + enabledForPrompt: true, + }, + ], + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate long-running tool + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 500) + }) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Spy on console.log to verify the skip message + const consoleLogSpy = vi.spyOn(console, "log") + + // Start a tool execution + const toolPromise = mcpHub.callTool("test-server", "test-tool", {}) + + // While tool is running, toggle alwaysAllow + await mcpHub.toggleToolAlwaysAllow("test-server", "global", "test-tool", true) + + // Verify that the skip message was logged + expect(consoleLogSpy).toHaveBeenCalledWith( + "Skipping server restart for test-server - tools are currently executing", + ) + + // Verify that the config was updated + const writeCalls = vi.mocked(fs.writeFile).mock.calls + const lastWriteCall = writeCalls[writeCalls.length - 1] + const writtenConfig = JSON.parse(lastWriteCall[1] as string) + expect(writtenConfig.mcpServers["test-server"].alwaysAllow).toContain("test-tool") + + // Verify that the in-memory tool state was updated + const tool = mockConnection.server.tools?.find((t) => t.name === "test-tool") + expect(tool?.alwaysAllow).toBe(true) + + // Wait for tool to complete + await toolPromise + + // Verify that the server status is still connected + expect(mockConnection.server.status).toBe("connected") + + consoleLogSpy.mockRestore() + }) + + it("should prevent server restart when tools are running", async () => { + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ type: "stdio", command: "node", args: ["test.js"] }), + status: "connected", + source: "global", + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate long-running tool + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 500) + }) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Start a tool execution + const toolPromise = mcpHub.callTool("test-server", "some-tool", {}) + + // Try to restart the server while tool is running + await mcpHub.restartConnection("test-server", "global") + + // Verify that the server was not restarted + expect(mcpHub.connections[0]).toBe(mockConnection) + expect(mockConnection.server.status).toBe("connected") + + // Wait for tool to complete + await toolPromise + }) + + it("should prevent toggling server disabled state when tools are running", async () => { + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ type: "stdio", command: "node", args: ["test.js"] }), + status: "connected", + source: "global", + disabled: false, + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate long-running tool + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 500) + }) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Start a tool execution + const toolPromise = mcpHub.callTool("test-server", "some-tool", {}) + + // Try to disable the server while tool is running + await mcpHub.toggleServerDisabled("test-server", true, "global") + + // Verify that the server was not disabled + expect(mockConnection.server.disabled).toBe(false) + expect(mockConnection.server.status).toBe("connected") + + // Wait for tool to complete + await toolPromise + }) }) describe("toggleToolEnabledForPrompt", () => { From e8356585f80be26b235ad1ff315bf90d5f40f55e Mon Sep 17 00:00:00 2001 From: Roo Code Date: Fri, 22 Aug 2025 02:52:35 +0000 Subject: [PATCH 2/2] fix: address critical issues in MCP server restart handling - Add pendingRestarts processing after tool execution completes - Add cleanup for activeToolExecutions with timeout mechanism to prevent memory leaks - Refactor updateServerToolList to reduce code duplication using updateServerToolListInternal - Fix race condition in watcher re-enable logic by checking disposal state - Add comprehensive test coverage for all fixes --- src/services/mcp/McpHub.ts | 254 ++++++++++++---------- src/services/mcp/__tests__/McpHub.spec.ts | 240 +++++++++++++++++++- 2 files changed, 381 insertions(+), 113 deletions(-) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index e9bbf75614..03ffd11538 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -152,6 +152,8 @@ export class McpHub { private refCount: number = 0 // Reference counter for active clients private configChangeDebounceTimers: Map = new Map() private activeToolExecutions: Map> = new Map() // Track active tool executions per server + private pendingRestarts: Set = new Set() // Track servers that need restart after tool execution + private toolExecutionTimeouts: Map = new Map() // Track timeouts for tool execution cleanup constructor(provider: ClineProvider) { this.providerRef = new WeakRef(provider) @@ -1172,10 +1174,13 @@ export class McpHub { async restartConnection(serverName: string, source?: "global" | "project"): Promise { // Check if there are active tool executions for this server if (this.hasActiveToolExecutions(serverName, source)) { - console.log(`Skipping restart for ${serverName} - tools are currently executing`) - vscode.window.showWarningMessage( - t("mcp:errors.cannot_restart_tools_running", { serverName }) || - `Cannot restart server "${serverName}" while tools are running. Please wait for tool execution to complete.`, + console.log(`Deferring restart for ${serverName} - tools are currently executing`) + // Add to pending restarts + const serverKey = `${serverName}:${source || "global"}` + this.pendingRestarts.add(serverKey) + vscode.window.showInformationMessage( + t("mcp:info.restart_deferred", { serverName }) || + `Server "${serverName}" will be restarted after tool execution completes.`, ) return } @@ -1636,6 +1641,15 @@ export class McpHub { const executionId = `${toolName}:${Date.now()}` this.activeToolExecutions.get(serverKey)!.add(executionId) + // Set up a timeout to clean up the execution tracking in case of hangs + const cleanupTimeout = setTimeout(() => { + this.cleanupToolExecution(serverKey, executionId) + }, timeout + 5000) // Add 5 seconds buffer after the tool timeout + + // Store the timeout so we can clear it if the tool completes normally + const timeoutKey = `${serverKey}:${executionId}` + this.toolExecutionTimeouts.set(timeoutKey, cleanupTimeout) + try { const result = await connection.client.request( { @@ -1651,23 +1665,60 @@ export class McpHub { }, ) + // Clear the cleanup timeout since the tool completed successfully + this.clearToolExecutionTimeout(timeoutKey) + // Remove from active executions on success - this.activeToolExecutions.get(serverKey)?.delete(executionId) - if (this.activeToolExecutions.get(serverKey)?.size === 0) { - this.activeToolExecutions.delete(serverKey) - } + this.cleanupToolExecution(serverKey, executionId) return result } catch (error) { + // Clear the cleanup timeout + this.clearToolExecutionTimeout(timeoutKey) + // Remove from active executions on error - this.activeToolExecutions.get(serverKey)?.delete(executionId) - if (this.activeToolExecutions.get(serverKey)?.size === 0) { - this.activeToolExecutions.delete(serverKey) - } + this.cleanupToolExecution(serverKey, executionId) + throw error } } + /** + * Clean up tool execution tracking and process pending restarts if needed + * @param serverKey The server key + * @param executionId The execution ID to clean up + */ + private cleanupToolExecution(serverKey: string, executionId: string): void { + // Remove from active executions + this.activeToolExecutions.get(serverKey)?.delete(executionId) + if (this.activeToolExecutions.get(serverKey)?.size === 0) { + this.activeToolExecutions.delete(serverKey) + + // Check if this server has pending restarts + if (this.pendingRestarts.has(serverKey)) { + this.pendingRestarts.delete(serverKey) + // Extract server name and source from the key + const [serverName, source] = serverKey.split(":") + // Process the pending restart + this.restartConnection(serverName, source as "global" | "project").catch((error) => { + console.error(`Failed to process pending restart for ${serverName}:`, error) + }) + } + } + } + + /** + * Clear a tool execution timeout + * @param timeoutKey The timeout key to clear + */ + private clearToolExecutionTimeout(timeoutKey: string): void { + const timeout = this.toolExecutionTimeouts.get(timeoutKey) + if (timeout) { + clearTimeout(timeout) + this.toolExecutionTimeouts.delete(timeoutKey) + } + } + /** * Check if any tools are currently executing for a specific server * @param serverName The name of the server to check @@ -1708,6 +1759,27 @@ export class McpHub { toolName: string, listName: "alwaysAllow" | "disabledTools", addTool: boolean, + ): Promise { + // Use the common implementation with restart enabled + await this.updateServerToolListInternal(serverName, source, toolName, listName, addTool, true) + } + + /** + * Internal implementation for updating server tool lists + * @param serverName The name of the server to update + * @param source Whether to update the global or project config + * @param toolName The name of the tool to add or remove + * @param listName The name of the list to modify ("alwaysAllow" or "disabledTools") + * @param addTool Whether to add (true) or remove (false) the tool from the list + * @param allowRestart Whether to allow file watchers to trigger restart + */ + private async updateServerToolListInternal( + serverName: string, + source: "global" | "project", + toolName: string, + listName: "alwaysAllow" | "disabledTools", + addTool: boolean, + allowRestart: boolean, ): Promise { // Find the connection with matching name and source const connection = this.findConnection(serverName, source) @@ -1731,7 +1803,6 @@ export class McpHub { } // Normalize path for cross-platform compatibility - // Use a consistent path format for both reading and writing const normalizedPath = process.platform === "win32" ? configPath.replace(/\\/g, "/") : configPath // Read the appropriate config file @@ -1763,12 +1834,57 @@ export class McpHub { targetList.splice(toolIndex, 1) } + // Handle file watcher based on allowRestart flag + let watcherToRestore: vscode.FileSystemWatcher | undefined + if (!allowRestart) { + // Temporarily disable the watcher to prevent restart + if (source === "project" && this.projectMcpWatcher) { + watcherToRestore = this.projectMcpWatcher + this.projectMcpWatcher.dispose() + this.projectMcpWatcher = undefined + } else if (source === "global" && this.settingsWatcher) { + watcherToRestore = this.settingsWatcher + this.settingsWatcher.dispose() + this.settingsWatcher = undefined + } + } + await fs.writeFile(normalizedPath, JSON.stringify(config, null, 2)) if (connection) { - connection.server.tools = await this.fetchToolsList(serverName, source) + if (allowRestart) { + // Normal flow - fetch fresh tool list + connection.server.tools = await this.fetchToolsList(serverName, source) + } else { + // Update the in-memory tool state without restarting + const tools = connection.server.tools + if (tools) { + const tool = tools.find((t) => t.name === toolName) + if (tool) { + if (listName === "alwaysAllow") { + tool.alwaysAllow = addTool + } else if (listName === "disabledTools") { + tool.enabledForPrompt = !addTool + } + } + } + } await this.notifyWebviewOfServerChanges() } + + // Re-enable the watcher after a short delay if it was disabled + if (!allowRestart && watcherToRestore) { + setTimeout(async () => { + // Check if not disposed during the timeout + if (!this.isDisposed) { + if (source === "project" && !this.projectMcpWatcher) { + await this.watchProjectMcpFile() + } else if (source === "global" && !this.settingsWatcher) { + await this.watchMcpSettingsFile() + } + } + }, 1000) + } } async toggleToolAlwaysAllow( @@ -1780,7 +1896,7 @@ export class McpHub { try { // Check if there are active tool executions for this server if (this.hasActiveToolExecutions(serverName, source)) { - console.log(`Skipping server restart for ${serverName} - tools are currently executing`) + console.log(`Deferring config update for ${serverName} - tools are currently executing`) // Update the config file without triggering a restart await this.updateServerToolListWithoutRestart(serverName, source, toolName, "alwaysAllow", shouldAllow) } else { @@ -1807,101 +1923,8 @@ export class McpHub { listName: "alwaysAllow" | "disabledTools", addTool: boolean, ): Promise { - // Find the connection with matching name and source - const connection = this.findConnection(serverName, source) - - if (!connection) { - throw new Error(`Server ${serverName} with source ${source} not found`) - } - - // Determine the correct config path based on the source - let configPath: string - if (source === "project") { - // Get project MCP config path - const projectMcpPath = await this.getProjectMcpPath() - if (!projectMcpPath) { - throw new Error("Project MCP configuration file not found") - } - configPath = projectMcpPath - } else { - // Get global MCP settings path - configPath = await this.getMcpSettingsFilePath() - } - - // Normalize path for cross-platform compatibility - const normalizedPath = process.platform === "win32" ? configPath.replace(/\\/g, "/") : configPath - - // Read the appropriate config file - const content = await fs.readFile(normalizedPath, "utf-8") - const config = JSON.parse(content) - - if (!config.mcpServers) { - config.mcpServers = {} - } - - if (!config.mcpServers[serverName]) { - config.mcpServers[serverName] = { - type: "stdio", - command: "node", - args: [], // Default to an empty array; can be set later if needed - } - } - - if (!config.mcpServers[serverName][listName]) { - config.mcpServers[serverName][listName] = [] - } - - const targetList = config.mcpServers[serverName][listName] - const toolIndex = targetList.indexOf(toolName) - - if (addTool && toolIndex === -1) { - targetList.push(toolName) - } else if (!addTool && toolIndex !== -1) { - targetList.splice(toolIndex, 1) - } - - // Write the config file directly without triggering file watcher - // We'll temporarily disable the watcher to prevent restart - const watcherKey = source === "project" ? this.projectMcpWatcher : this.settingsWatcher - const wasWatcherActive = !!watcherKey - - if (wasWatcherActive && source === "project" && this.projectMcpWatcher) { - this.projectMcpWatcher.dispose() - this.projectMcpWatcher = undefined - } else if (wasWatcherActive && source === "global" && this.settingsWatcher) { - this.settingsWatcher.dispose() - this.settingsWatcher = undefined - } - - await fs.writeFile(normalizedPath, JSON.stringify(config, null, 2)) - - // Update the in-memory tool state without restarting - if (connection) { - // Update the tool's alwaysAllow or enabledForPrompt status in memory - const tools = connection.server.tools - if (tools) { - const tool = tools.find((t) => t.name === toolName) - if (tool) { - if (listName === "alwaysAllow") { - tool.alwaysAllow = addTool - } else if (listName === "disabledTools") { - tool.enabledForPrompt = !addTool - } - } - } - await this.notifyWebviewOfServerChanges() - } - - // Re-enable the watcher after a short delay - if (wasWatcherActive) { - setTimeout(async () => { - if (source === "project") { - await this.watchProjectMcpFile() - } else { - await this.watchMcpSettingsFile() - } - }, 1000) - } + // Use the common implementation + await this.updateServerToolListInternal(serverName, source, toolName, listName, addTool, false) } async toggleToolEnabledForPrompt( @@ -1917,7 +1940,7 @@ export class McpHub { // Check if there are active tool executions for this server if (this.hasActiveToolExecutions(serverName, source)) { - console.log(`Skipping server restart for ${serverName} - tools are currently executing`) + console.log(`Deferring config update for ${serverName} - tools are currently executing`) // Update the config file without triggering a restart await this.updateServerToolListWithoutRestart( serverName, @@ -2004,6 +2027,15 @@ export class McpHub { } this.configChangeDebounceTimers.clear() + // Clear all tool execution timeouts + for (const timeout of this.toolExecutionTimeouts.values()) { + clearTimeout(timeout) + } + this.toolExecutionTimeouts.clear() + + // Clear pending restarts + this.pendingRestarts.clear() + this.removeAllFileWatchers() for (const connection of this.connections) { try { diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 9165851199..37d4b44476 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -968,9 +968,9 @@ describe("McpHub", () => { // While tool is running, toggle alwaysAllow await mcpHub.toggleToolAlwaysAllow("test-server", "global", "test-tool", true) - // Verify that the skip message was logged + // Verify that the defer message was logged expect(consoleLogSpy).toHaveBeenCalledWith( - "Skipping server restart for test-server - tools are currently executing", + "Deferring config update for test-server - tools are currently executing", ) // Verify that the config was updated @@ -1064,6 +1064,126 @@ describe("McpHub", () => { // Wait for tool to complete await toolPromise }) + + it("should process pending restarts after tool execution completes", async () => { + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ type: "stdio", command: "node", args: ["test.js"] }), + status: "connected", + source: "global", + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate tool execution + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 100) + }) + }), + close: vi.fn().mockResolvedValue(undefined), + } as any, + transport: { + close: vi.fn().mockResolvedValue(undefined), + } as any, + } + + mcpHub.connections = [mockConnection] + + // Start a tool execution + const toolPromise = mcpHub.callTool("test-server", "some-tool", {}) + + // Try to restart while tool is running - should be deferred + await mcpHub.restartConnection("test-server", "global") + + // Verify server is still connected (restart was deferred) + expect(mockConnection.server.status).toBe("connected") + + // Wait for tool to complete + await toolPromise + + // Give time for pending restart to process + await new Promise((resolve) => setTimeout(resolve, 200)) + + // Verify that the connection was restarted (old connection closed) + expect(mockConnection.client.close).toHaveBeenCalled() + expect(mockConnection.transport.close).toHaveBeenCalled() + }) + + it("should clean up activeToolExecutions on timeout", async () => { + vi.useFakeTimers() + + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ type: "stdio", command: "node", args: ["test.js"], timeout: 1 }), + status: "connected", + source: "global", + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate a hanging tool that never resolves + return new Promise(() => {}) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Start a tool execution that will hang + const toolPromise = mcpHub.callTool("test-server", "hanging-tool", {}).catch(() => {}) + + // Fast-forward time to trigger the cleanup timeout (1 second tool timeout + 5 second buffer) + vi.advanceTimersByTime(6000) + + // Check that activeToolExecutions has been cleaned up + const serverKey = "test-server:global" + const activeExecutions = mcpHub["activeToolExecutions"].get(serverKey) + expect(activeExecutions).toBeUndefined() + + vi.useRealTimers() + }) + + it("should not create duplicate entries in pendingRestarts", async () => { + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ type: "stdio", command: "node", args: ["test.js"] }), + status: "connected", + source: "global", + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate long-running tool + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 500) + }) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Start a tool execution + const toolPromise = mcpHub.callTool("test-server", "some-tool", {}) + + // Try to restart multiple times while tool is running + await mcpHub.restartConnection("test-server", "global") + await mcpHub.restartConnection("test-server", "global") + await mcpHub.restartConnection("test-server", "global") + + // Check that pendingRestarts only has one entry + const serverKey = "test-server:global" + expect(mcpHub["pendingRestarts"].has(serverKey)).toBe(true) + expect(mcpHub["pendingRestarts"].size).toBe(1) + + // Wait for tool to complete + await toolPromise + }) }) describe("toggleToolEnabledForPrompt", () => { @@ -2300,4 +2420,120 @@ describe("McpHub", () => { ) }) }) + + describe("Race condition fixes", () => { + it("should check disposal state before re-enabling watchers", async () => { + // Mock fs.readFile to return existing config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + alwaysAllow: [], + }, + }, + }), + ) + + // Create a connected server + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ + type: "stdio", + command: "node", + args: ["test.js"], + alwaysAllow: [], + }), + status: "connected", + source: "global", + tools: [ + { + name: "test-tool", + description: "Test tool", + alwaysAllow: false, + enabledForPrompt: true, + }, + ], + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate tool execution + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 100) + }) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Start a tool execution + const toolPromise = mcpHub.callTool("test-server", "test-tool", {}) + + // Toggle alwaysAllow while tool is running (will disable watcher temporarily) + await mcpHub.toggleToolAlwaysAllow("test-server", "global", "test-tool", true) + + // Dispose the hub before the watcher re-enable timeout + await mcpHub.dispose() + + // Wait for tool to complete + await toolPromise + + // Wait for the watcher re-enable timeout to fire + await new Promise((resolve) => setTimeout(resolve, 1100)) + + // Verify that no watcher was re-enabled after disposal + expect(mcpHub["settingsWatcher"]).toBeUndefined() + expect(mcpHub["projectMcpWatcher"]).toBeUndefined() + }) + + it("should handle concurrent tool executions correctly", async () => { + // Create a connected server + const mockConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ + type: "stdio", + command: "node", + args: ["test.js"], + }), + status: "connected", + source: "global", + }, + client: { + request: vi.fn().mockImplementation(() => { + // Simulate tool execution + return new Promise((resolve) => { + setTimeout(() => resolve({ content: [] }), 200) + }) + }), + } as any, + transport: {} as any, + } + + mcpHub.connections = [mockConnection] + + // Start multiple tool executions + const tool1Promise = mcpHub.callTool("test-server", "tool1", {}) + const tool2Promise = mcpHub.callTool("test-server", "tool2", {}) + + // Verify both tools are being tracked as active + const serverKey = "test-server:global" + const activeExecutions = mcpHub["activeToolExecutions"].get(serverKey) + expect(activeExecutions?.size).toBe(2) + + // Wait for tools to complete + await Promise.all([tool1Promise, tool2Promise]) + + // Verify that activeToolExecutions has been cleaned up + const activeExecutionsAfter = mcpHub["activeToolExecutions"].get(serverKey) + expect(activeExecutionsAfter).toBeUndefined() + }) + }) })