diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index e8dbffb62d..4a0b2ebc6f 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -133,6 +133,7 @@ export const globalSettingsSchema = z.object({ mcpEnabled: z.boolean().optional(), enableMcpServerCreation: z.boolean().optional(), + mcpServers: z.record(z.string(), z.any()).optional(), remoteControlEnabled: z.boolean().optional(), diff --git a/src/extension/__tests__/api.spec.ts b/src/extension/__tests__/api.spec.ts new file mode 100644 index 0000000000..b098a128eb --- /dev/null +++ b/src/extension/__tests__/api.spec.ts @@ -0,0 +1,206 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" +import * as vscode from "vscode" +import { API } from "../api" +import { ClineProvider } from "../../core/webview/ClineProvider" +import { Package } from "../../shared/package" + +// Mock vscode module +vi.mock("vscode", () => ({ + commands: { + executeCommand: vi.fn(), + }, + workspace: { + getConfiguration: vi.fn(() => ({ + update: vi.fn(), + })), + }, + window: { + showErrorMessage: vi.fn(), + showInformationMessage: vi.fn(), + showWarningMessage: vi.fn(), + createTextEditorDecorationType: vi.fn().mockReturnValue({ + dispose: vi.fn(), + }), + }, + ConfigurationTarget: { + Global: 1, + }, +})) + +describe("API", () => { + let api: API + let mockOutputChannel: any + let mockProvider: any + let mockMcpHub: any + + beforeEach(() => { + // Create mock output channel + mockOutputChannel = { + appendLine: vi.fn(), + } + + // Create mock MCP hub + mockMcpHub = { + initializeRuntimeMcpServers: vi.fn(), + } + + // Create mock provider + mockProvider = { + context: { + extension: { + packageJSON: { + version: "1.0.0", + }, + }, + }, + setValues: vi.fn(), + removeClineFromStack: vi.fn(), + postStateToWebview: vi.fn(), + postMessageToWebview: vi.fn(), + createTask: vi.fn().mockResolvedValue({ taskId: "test-task-id" }), + getMcpHub: vi.fn().mockReturnValue(mockMcpHub), + on: vi.fn(), + off: vi.fn(), + } + + // Create API instance + api = new API(mockOutputChannel, mockProvider as any) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe("startNewTask", () => { + it("should initialize runtime MCP servers when mcpServers is provided in configuration", async () => { + const configuration = { + apiProvider: "openai" as const, + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + }, + }, + } + + const taskId = await api.startNewTask({ + configuration, + text: "Test task", + }) + + // Verify MCP hub was retrieved + expect(mockProvider.getMcpHub).toHaveBeenCalled() + + // Verify runtime MCP servers were initialized + expect(mockMcpHub.initializeRuntimeMcpServers).toHaveBeenCalledWith(configuration.mcpServers) + + // Verify other methods were called + expect(mockProvider.setValues).toHaveBeenCalledWith(configuration) + expect(mockProvider.createTask).toHaveBeenCalled() + expect(taskId).toBe("test-task-id") + }) + + it("should not initialize MCP servers when mcpServers is not provided", async () => { + const configuration = { + apiProvider: "openai" as const, + } + + await api.startNewTask({ + configuration, + text: "Test task", + }) + + // Verify MCP hub was not retrieved + expect(mockProvider.getMcpHub).not.toHaveBeenCalled() + + // Verify runtime MCP servers were not initialized + expect(mockMcpHub.initializeRuntimeMcpServers).not.toHaveBeenCalled() + + // Verify other methods were still called + expect(mockProvider.setValues).toHaveBeenCalledWith(configuration) + expect(mockProvider.createTask).toHaveBeenCalled() + }) + + it("should handle when MCP hub is not available", async () => { + // Make getMcpHub return undefined + mockProvider.getMcpHub.mockReturnValue(undefined) + + const configuration = { + apiProvider: "openai" as const, + mcpServers: { + "test-server": { + type: "stdio", + command: "node", + args: ["test.js"], + }, + }, + } + + // Should not throw an error + const taskId = await api.startNewTask({ + configuration, + text: "Test task", + }) + + // Verify MCP hub was retrieved + expect(mockProvider.getMcpHub).toHaveBeenCalled() + + // Verify runtime MCP servers were not initialized (since hub is undefined) + expect(mockMcpHub.initializeRuntimeMcpServers).not.toHaveBeenCalled() + + // Verify other methods were still called + expect(mockProvider.setValues).toHaveBeenCalledWith(configuration) + expect(mockProvider.createTask).toHaveBeenCalled() + expect(taskId).toBe("test-task-id") + }) + + it("should handle complex MCP server configurations", async () => { + const configuration = { + apiProvider: "openai" as const, + mcpServers: { + "stdio-server": { + type: "stdio", + command: "node", + args: ["server.js"], + env: { NODE_ENV: "production" }, + disabled: false, + alwaysAllow: ["tool1", "tool2"], + }, + "sse-server": { + type: "sse", + url: "http://localhost:8080/sse", + headers: { Authorization: "Bearer token" }, + disabled: true, + }, + }, + } + + await api.startNewTask({ + configuration, + text: "Test task", + }) + + // Verify runtime MCP servers were initialized with the full configuration + expect(mockMcpHub.initializeRuntimeMcpServers).toHaveBeenCalledWith(configuration.mcpServers) + }) + + it("should handle empty mcpServers object", async () => { + const configuration = { + apiProvider: "openai" as const, + mcpServers: {}, + } + + await api.startNewTask({ + configuration, + text: "Test task", + }) + + // Verify MCP hub was retrieved + expect(mockProvider.getMcpHub).toHaveBeenCalled() + + // Verify runtime MCP servers were initialized with empty object + expect(mockMcpHub.initializeRuntimeMcpServers).toHaveBeenCalledWith({}) + }) + }) +}) diff --git a/src/extension/api.ts b/src/extension/api.ts index 9c38aabfdb..b157f3625e 100644 --- a/src/extension/api.ts +++ b/src/extension/api.ts @@ -129,6 +129,15 @@ export class API extends EventEmitter implements RooCodeAPI { } if (configuration) { + // Handle MCP servers if provided in configuration + if (configuration.mcpServers && typeof configuration.mcpServers === "object") { + const mcpHub = provider.getMcpHub() + if (mcpHub) { + // Initialize runtime MCP servers + await mcpHub.initializeRuntimeMcpServers(configuration.mcpServers) + } + } + await provider.setValues(configuration) if (configuration.allowedCommands) { diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 6ec5b839e8..6c6ee851c5 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -151,6 +151,7 @@ export class McpHub { isConnecting: boolean = false private refCount: number = 0 // Reference counter for active clients private configChangeDebounceTimers: Map = new Map() + private runtimeServers: Record = {} // Store runtime MCP servers constructor(provider: ClineProvider) { this.providerRef = new WeakRef(provider) @@ -545,6 +546,21 @@ export class McpHub { } } + /** + * Initialize runtime MCP servers passed programmatically (e.g., through IPC API) + * @param servers Record of server configurations + */ + public async initializeRuntimeMcpServers(servers: Record): Promise { + this.runtimeServers = servers || {} + + // Initialize runtime servers with "runtime" source + try { + await this.updateServerConnections(this.runtimeServers, "runtime" as any, false) + } catch (error) { + this.showErrorMessage("Failed to initialize runtime MCP servers", error) + } + } + private async initializeGlobalMcpServers(): Promise { await this.initializeMcpServers("global") } @@ -583,7 +599,7 @@ export class McpHub { private createPlaceholderConnection( name: string, config: z.infer, - source: "global" | "project", + source: "global" | "project" | "runtime", reason: DisableReason, ): DisconnectedMcpConnection { return { @@ -618,7 +634,7 @@ export class McpHub { private async connectToServer( name: string, config: z.infer, - source: "global" | "project" = "global", + source: "global" | "project" | "runtime" = "global", ): Promise { // Remove existing connection if it exists with the same source await this.deleteConnection(name, source) @@ -884,10 +900,10 @@ export class McpHub { /** * Helper method to find a connection by server name and source * @param serverName The name of the server to find - * @param source Optional source to filter by (global or project) + * @param source Optional source to filter by (global, project, or runtime) * @returns The matching connection or undefined if not found */ - private findConnection(serverName: string, source?: "global" | "project"): McpConnection | undefined { + private findConnection(serverName: string, source?: "global" | "project" | "runtime"): McpConnection | undefined { // If source is specified, only find servers with that source if (source !== undefined) { return this.connections.find((conn) => conn.server.name === serverName && conn.server.source === source) @@ -906,7 +922,7 @@ export class McpHub { ) } - private async fetchToolsList(serverName: string, source?: "global" | "project"): Promise { + private async fetchToolsList(serverName: string, source?: "global" | "project" | "runtime"): Promise { try { // Use the helper method to find the connection const connection = this.findConnection(serverName, source) @@ -963,7 +979,10 @@ export class McpHub { } } - private async fetchResourcesList(serverName: string, source?: "global" | "project"): Promise { + private async fetchResourcesList( + serverName: string, + source?: "global" | "project" | "runtime", + ): Promise { try { const connection = this.findConnection(serverName, source) if (!connection || connection.type !== "connected") { @@ -979,7 +998,7 @@ export class McpHub { private async fetchResourceTemplatesList( serverName: string, - source?: "global" | "project", + source?: "global" | "project" | "runtime", ): Promise { try { const connection = this.findConnection(serverName, source) @@ -997,7 +1016,7 @@ export class McpHub { } } - async deleteConnection(name: string, source?: "global" | "project"): Promise { + async deleteConnection(name: string, source?: "global" | "project" | "runtime"): Promise { // Clean up file watchers for this server this.removeFileWatchersForServer(name) @@ -1027,7 +1046,7 @@ export class McpHub { async updateServerConnections( newServers: Record, - source: "global" | "project" = "global", + source: "global" | "project" | "runtime" = "global", manageConnectingState: boolean = true, ): Promise { if (manageConnectingState) { @@ -1097,7 +1116,7 @@ export class McpHub { private setupFileWatcher( name: string, config: z.infer, - source: "global" | "project" = "global", + source: "global" | "project" | "runtime" = "global", ) { // Initialize an empty array for this server if it doesn't exist if (!this.fileWatchers.has(name)) { @@ -1170,7 +1189,7 @@ export class McpHub { } } - async restartConnection(serverName: string, source?: "global" | "project"): Promise { + async restartConnection(serverName: string, source?: "global" | "project" | "runtime"): Promise { this.isConnecting = true // Check if MCP is globally enabled @@ -1271,6 +1290,10 @@ export class McpHub { // This ensures proper initialization including fetching tools, resources, etc. await this.initializeMcpServers("global") await this.initializeMcpServers("project") + // Re-initialize runtime servers if any + if (Object.keys(this.runtimeServers).length > 0) { + await this.initializeRuntimeMcpServers(this.runtimeServers) + } await delay(100) @@ -1302,9 +1325,13 @@ export class McpHub { } } - // Sort connections: first project servers in their defined order, then global servers in their defined order - // This ensures that when servers have the same name, project servers are prioritized + // Sort connections: first runtime servers, then project servers, then global servers + // This ensures that runtime servers have highest priority const sortedConnections = [...this.connections].sort((a, b) => { + // Runtime servers come first + if (a.server.source === "runtime" && b.server.source !== "runtime") return -1 + if (b.server.source === "runtime" && a.server.source !== "runtime") return 1 + const aIsGlobal = a.server.source === "global" || !a.server.source const bIsGlobal = b.server.source === "global" || !b.server.source @@ -1349,7 +1376,7 @@ export class McpHub { public async toggleServerDisabled( serverName: string, disabled: boolean, - source?: "global" | "project", + source?: "global" | "project" | "runtime", ): Promise { try { // Find the connection to determine if it's a global or project server @@ -1410,8 +1437,20 @@ export class McpHub { private async updateServerConfig( serverName: string, configUpdate: Record, - source: "global" | "project" = "global", + source: "global" | "project" | "runtime" = "global", ): Promise { + // Runtime servers are not persisted to files + if (source === "runtime") { + // Update the runtime server configuration in memory + if (this.runtimeServers[serverName]) { + this.runtimeServers[serverName] = { + ...this.runtimeServers[serverName], + ...configUpdate, + } + } + return + } + // Determine which config file to update let configPath: string if (source === "project") { @@ -1473,7 +1512,7 @@ export class McpHub { public async updateServerTimeout( serverName: string, timeout: number, - source?: "global" | "project", + source?: "global" | "project" | "runtime", ): Promise { try { // Find the connection to determine if it's a global or project server @@ -1492,7 +1531,7 @@ export class McpHub { } } - public async deleteServer(serverName: string, source?: "global" | "project"): Promise { + public async deleteServer(serverName: string, source?: "global" | "project" | "runtime"): Promise { try { // Find the connection to determine if it's a global or project server const connection = this.findConnection(serverName, source) @@ -1501,6 +1540,16 @@ export class McpHub { } const serverSource = connection.server.source || "global" + + // Runtime servers are not persisted to files, just remove from connections + if (serverSource === "runtime") { + await this.deleteConnection(serverName, "runtime") + delete this.runtimeServers[serverName] + await this.notifyWebviewOfServerChanges() + vscode.window.showInformationMessage(t("mcp:info.server_deleted", { serverName })) + return + } + // Determine config file based on server source const isProjectServer = serverSource === "project" let configPath: string @@ -1560,7 +1609,11 @@ export class McpHub { } } - async readResource(serverName: string, uri: string, source?: "global" | "project"): Promise { + async readResource( + serverName: string, + uri: string, + source?: "global" | "project" | "runtime", + ): Promise { const connection = this.findConnection(serverName, source) if (!connection || connection.type !== "connected") { throw new Error(`No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}`) @@ -1583,7 +1636,7 @@ export class McpHub { serverName: string, toolName: string, toolArguments?: Record, - source?: "global" | "project", + source?: "global" | "project" | "runtime", ): Promise { const connection = this.findConnection(serverName, source) if (!connection || connection.type !== "connected") { @@ -1631,11 +1684,36 @@ export class McpHub { */ private async updateServerToolList( serverName: string, - source: "global" | "project", + source: "global" | "project" | "runtime", toolName: string, listName: "alwaysAllow" | "disabledTools", addTool: boolean, ): Promise { + // Runtime servers are not persisted to files + if (source === "runtime") { + if (this.runtimeServers[serverName]) { + if (!this.runtimeServers[serverName][listName]) { + this.runtimeServers[serverName][listName] = [] + } + const targetList = this.runtimeServers[serverName][listName] + const toolIndex = targetList.indexOf(toolName) + + if (addTool && toolIndex === -1) { + targetList.push(toolName) + } else if (!addTool && toolIndex !== -1) { + targetList.splice(toolIndex, 1) + } + + // Update the connection + const connection = this.findConnection(serverName, source) + if (connection) { + connection.server.tools = await this.fetchToolsList(serverName, source) + await this.notifyWebviewOfServerChanges() + } + } + return + } + // Find the connection with matching name and source const connection = this.findConnection(serverName, source) @@ -1700,7 +1778,7 @@ export class McpHub { async toggleToolAlwaysAllow( serverName: string, - source: "global" | "project", + source: "global" | "project" | "runtime", toolName: string, shouldAllow: boolean, ): Promise { @@ -1717,7 +1795,7 @@ export class McpHub { async toggleToolEnabledForPrompt( serverName: string, - source: "global" | "project", + source: "global" | "project" | "runtime", toolName: string, isEnabled: boolean, ): Promise { diff --git a/src/shared/mcp.ts b/src/shared/mcp.ts index ef1d51bad3..88c112da89 100644 --- a/src/shared/mcp.ts +++ b/src/shared/mcp.ts @@ -15,7 +15,7 @@ export type McpServer = { resourceTemplates?: McpResourceTemplate[] disabled?: boolean timeout?: number - source?: "global" | "project" + source?: "global" | "project" | "runtime" projectPath?: string instructions?: string }