Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,23 @@ export const webviewMessageHandler = async (
case "mcpEnabled":
const mcpEnabled = message.bool ?? true
await updateGlobalState("mcpEnabled", mcpEnabled)

// If MCP is being disabled, disconnect all servers
const mcpHubInstance = provider.getMcpHub()
if (!mcpEnabled && mcpHubInstance) {
// Disconnect all existing connections
const existingConnections = [...mcpHubInstance.connections]
for (const conn of existingConnections) {
await mcpHubInstance.deleteConnection(conn.server.name, conn.server.source)
}

// Re-initialize servers to track them in disconnected state
await mcpHubInstance.refreshAllConnections()
} else if (mcpEnabled && mcpHubInstance) {
// If MCP is being enabled, reconnect all servers
await mcpHubInstance.refreshAllConnections()
}

await provider.postStateToWebview()
break
case "enableMcpServerCreation":
Expand Down
106 changes: 102 additions & 4 deletions src/services/mcp/McpHub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ export class McpHub {
const result = McpSettingsSchema.safeParse(config)

if (result.success) {
// Pass all servers including disabled ones - they'll be handled in updateServerConnections
await this.updateServerConnections(result.data.mcpServers || {}, source, false)
} else {
const errorMessages = result.error.errors
Expand Down Expand Up @@ -560,6 +561,54 @@ export class McpHub {
// Remove existing connection if it exists with the same source
await this.deleteConnection(name, source)

// Check if MCP is globally enabled
const provider = this.providerRef.deref()
if (provider) {
const state = await provider.getState()
const mcpEnabled = state.mcpEnabled ?? true

// Skip connecting if MCP is globally disabled
if (!mcpEnabled) {
// Still create a connection object to track the server, but don't actually connect
const connection: McpConnection = {
server: {
name,
config: JSON.stringify(config),
status: "disconnected",
disabled: config.disabled,
source,
projectPath:
source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined,
errorHistory: [],
},
client: null as any, // We won't actually create a client when MCP is disabled
transport: null as any, // We won't actually create a transport when MCP is disabled
}
this.connections.push(connection)
return
}
}

// Skip connecting to disabled servers
if (config.disabled) {
// Still create a connection object to track the server, but don't actually connect
const connection: McpConnection = {
server: {
name,
config: JSON.stringify(config),
status: "disconnected",
disabled: true,
source,
projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined,
errorHistory: [],
},
client: null as any, // We won't actually create a client for disabled servers
transport: null as any, // We won't actually create a transport for disabled servers
}
this.connections.push(connection)
return
}

try {
const client = new Client(
{
Expand Down Expand Up @@ -975,15 +1024,21 @@ export class McpHub {
if (!currentConnection) {
// New server
try {
this.setupFileWatcher(name, validatedConfig, source)
// Only setup file watcher for enabled servers
if (!validatedConfig.disabled) {
this.setupFileWatcher(name, validatedConfig, source)
}
await this.connectToServer(name, validatedConfig, source)
} catch (error) {
this.showErrorMessage(`Failed to connect to new MCP server ${name}`, error)
}
} else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) {
// Existing server with changed config
try {
this.setupFileWatcher(name, validatedConfig, source)
// Only setup file watcher for enabled servers
if (!validatedConfig.disabled) {
this.setupFileWatcher(name, validatedConfig, source)
}
await this.deleteConnection(name, source)
await this.connectToServer(name, validatedConfig, source)
} catch (error) {
Expand Down Expand Up @@ -1073,6 +1128,16 @@ export class McpHub {
return
}

// Check if MCP is globally enabled
const state = await provider.getState()
const mcpEnabled = state.mcpEnabled ?? true

// Skip restarting if MCP is globally disabled
if (!mcpEnabled) {
this.isConnecting = false
return
}

// Get existing connection and update its status
const connection = this.findConnection(serverName, source)
const config = connection?.server.config
Expand Down Expand Up @@ -1111,6 +1176,29 @@ export class McpHub {
return
}

// Check if MCP is globally enabled
const provider = this.providerRef.deref()
if (provider) {
const state = await provider.getState()
const mcpEnabled = state.mcpEnabled ?? true

// Skip refreshing if MCP is globally disabled
if (!mcpEnabled) {
// Clear all existing connections
const existingConnections = [...this.connections]
for (const conn of existingConnections) {
await this.deleteConnection(conn.server.name, conn.server.source)
}

// Still initialize servers to track them, but they won't connect
await this.initializeMcpServers("global")
await this.initializeMcpServers("project")

await this.notifyWebviewOfServerChanges()
return
}
}

this.isConnecting = true
vscode.window.showInformationMessage(t("mcp:info.refreshing_all"))

Expand Down Expand Up @@ -1257,8 +1345,18 @@ export class McpHub {
try {
connection.server.disabled = disabled

// Only refresh capabilities if connected
if (connection.server.status === "connected") {
// If disabling a connected server, disconnect it
if (disabled && connection.server.status === "connected") {
await this.deleteConnection(serverName, serverSource)
// Re-add as a disabled connection
await this.connectToServer(serverName, JSON.parse(connection.server.config), serverSource)
} else if (!disabled && connection.server.status === "disconnected") {
// If enabling a disabled server, connect it
const config = JSON.parse(connection.server.config)
await this.deleteConnection(serverName, serverSource)
await this.connectToServer(serverName, config, serverSource)
} else if (connection.server.status === "connected") {
// Only refresh capabilities if connected
connection.server.tools = await this.fetchToolsList(serverName, serverSource)
connection.server.resources = await this.fetchResourcesList(serverName, serverSource)
connection.server.resourceTemplates = await this.fetchResourceTemplatesList(
Expand Down
117 changes: 117 additions & 0 deletions src/services/mcp/__tests__/McpHub.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ describe("McpHub", () => {
ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
postMessageToWebview: vi.fn(),
getState: vi.fn().mockResolvedValue({ mcpEnabled: true }),
context: {
subscriptions: [],
workspaceState: {} as any,
Expand Down Expand Up @@ -877,6 +878,122 @@ describe("McpHub", () => {
})
})

describe("MCP global enable/disable", () => {
beforeEach(() => {
// Clear all mocks before each test
vi.clearAllMocks()
})

it("should not connect to servers when MCP is globally disabled", async () => {
// Mock provider with mcpEnabled: false
const disabledMockProvider = {
ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
postMessageToWebview: vi.fn(),
getState: vi.fn().mockResolvedValue({ mcpEnabled: false }),
context: mockProvider.context,
}

// Mock the config file read with a different server name to avoid conflicts
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify({
mcpServers: {
"disabled-test-server": {
command: "node",
args: ["test.js"],
},
},
}),
)

// Create a new McpHub instance with disabled MCP
const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider)

// Wait for initialization
await new Promise((resolve) => setTimeout(resolve, 100))

// Find the disabled-test-server
const disabledServer = mcpHub.connections.find((conn) => conn.server.name === "disabled-test-server")

// Verify that the server is tracked but not connected
expect(disabledServer).toBeDefined()
expect(disabledServer!.server.status).toBe("disconnected")
expect(disabledServer!.client).toBeNull()
expect(disabledServer!.transport).toBeNull()
})

it("should connect to servers when MCP is globally enabled", async () => {
// Clear all mocks
vi.clearAllMocks()

// Mock StdioClientTransport
const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>

const mockTransport = {
start: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
stderr: {
on: vi.fn(),
},
onerror: null,
onclose: null,
}

StdioClientTransport.mockImplementation(() => mockTransport)

// Mock Client
const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
const Client = clientModule.Client as ReturnType<typeof vi.fn>

Client.mockImplementation(() => ({
connect: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
getInstructions: vi.fn().mockReturnValue("test instructions"),
request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
}))

// Mock provider with mcpEnabled: true
const enabledMockProvider = {
ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
postMessageToWebview: vi.fn(),
getState: vi.fn().mockResolvedValue({ mcpEnabled: true }),
context: mockProvider.context,
}

// Mock the config file read with a different server name
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify({
mcpServers: {
"enabled-test-server": {
command: "node",
args: ["test.js"],
},
},
}),
)

// Create a new McpHub instance with enabled MCP
const mcpHub = new McpHub(enabledMockProvider as unknown as ClineProvider)

// Wait for initialization
await new Promise((resolve) => setTimeout(resolve, 100))

// Find the enabled-test-server
const enabledServer = mcpHub.connections.find((conn) => conn.server.name === "enabled-test-server")

// Verify that the server is connected
expect(enabledServer).toBeDefined()
expect(enabledServer!.server.status).toBe("connected")
expect(enabledServer!.client).toBeDefined()
expect(enabledServer!.transport).toBeDefined()

// Verify StdioClientTransport was called
expect(StdioClientTransport).toHaveBeenCalled()
})
})

describe("Windows command wrapping", () => {
let StdioClientTransport: ReturnType<typeof vi.fn>
let Client: ReturnType<typeof vi.fn>
Expand Down
Loading