Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,34 @@ export class ClineProvider
await this.postStateToWebview()
})

// Initialize MCP Hub through the singleton manager
McpServerManager.getInstance(this.context, this)
.then((hub) => {
this.mcpHub = hub
this.mcpHub.registerClient()
})
.catch((error) => {
this.log(`Failed to initialize MCP Hub: ${error}`)
})
// Initialize MCP Hub through the singleton manager only if mcpEnabled
this.initializeMcpIfEnabled()

this.marketplaceManager = new MarketplaceManager(this.context)
}

private async initializeMcpIfEnabled() {
try {
// Get the mcpEnabled setting from global state
const mcpEnabled = this.contextProxy.getValue("mcpEnabled") ?? true

if (mcpEnabled) {
// Initialize MCP Hub through the singleton manager
const hub = await McpServerManager.getInstance(this.context, this, true)
this.mcpHub = hub
this.mcpHub.registerClient()
} else {
this.log("MCP is disabled, skipping MCP Hub initialization")
// Still create the hub instance but don't initialize servers
const hub = await McpServerManager.getInstance(this.context, this, false)
this.mcpHub = hub
this.mcpHub.registerClient()
}
} catch (error) {
this.log(`Failed to initialize MCP Hub: ${error}`)
}
}

// Adds a new Cline instance to clineStack, marking the start of a new task.
// The instance is pushed to the top of the stack (LIFO order).
// When the task is completed, the top instance is removed, reactivating the previous task.
Expand Down
11 changes: 11 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,17 @@ export const webviewMessageHandler = async (
case "mcpEnabled":
const mcpEnabled = message.bool ?? true
await updateGlobalState("mcpEnabled", mcpEnabled)

// Enable or disable MCP servers based on the setting
const mcpHubInstance = provider.getMcpHub()
if (mcpHubInstance) {
if (mcpEnabled) {
await mcpHubInstance.enableServers()
} else {
await mcpHubInstance.disableServers()
}
}

await provider.postStateToWebview()
break
case "enableMcpServerCreation":
Expand Down
43 changes: 40 additions & 3 deletions src/services/mcp/McpHub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,16 @@ export class McpHub {
private refCount: number = 0 // Reference counter for active clients
private configChangeDebounceTimers: Map<string, NodeJS.Timeout> = new Map()

constructor(provider: ClineProvider) {
constructor(provider: ClineProvider, autoInitialize: boolean = true) {
this.providerRef = new WeakRef(provider)
this.watchMcpSettingsFile()
this.watchProjectMcpFile().catch(console.error)
this.setupWorkspaceFoldersWatcher()
this.initializeGlobalMcpServers()
this.initializeProjectMcpServers()

if (autoInitialize) {
this.initializeGlobalMcpServers()
this.initializeProjectMcpServers()
}
}
/**
* Registers a client (e.g., ClineProvider) using this hub.
Expand Down Expand Up @@ -1643,4 +1646,38 @@ export class McpHub {
}
this.disposables.forEach((d) => d.dispose())
}

/**
* Enable MCP servers by initializing them
*/
async enableServers(): Promise<void> {
if (this.connections.length > 0) {
console.log("MCP servers are already initialized")
return
}

console.log("Enabling MCP servers...")
await this.initializeGlobalMcpServers()
await this.initializeProjectMcpServers()
await this.notifyWebviewOfServerChanges()
}

/**
* Disable MCP servers by disconnecting all connections
*/
async disableServers(): Promise<void> {
console.log("Disabling MCP servers...")

// Disconnect all servers
const allConnections = [...this.connections]
for (const conn of allConnections) {
await this.deleteConnection(conn.server.name, conn.server.source)
}

// Clear connections array
this.connections = []

// Notify webview of changes
await this.notifyWebviewOfServerChanges()
}
}
8 changes: 6 additions & 2 deletions src/services/mcp/McpServerManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ export class McpServerManager {
* Creates a new instance if one doesn't exist.
* Thread-safe implementation using a promise-based lock.
*/
static async getInstance(context: vscode.ExtensionContext, provider: ClineProvider): Promise<McpHub> {
static async getInstance(
context: vscode.ExtensionContext,
provider: ClineProvider,
autoInitialize: boolean = true,
): Promise<McpHub> {
// Register the provider
this.providers.add(provider)

Expand All @@ -36,7 +40,7 @@ export class McpServerManager {
try {
// Double-check instance in case it was created while we were waiting
if (!this.instance) {
this.instance = new McpHub(provider)
this.instance = new McpHub(provider, autoInitialize)
// Store a unique identifier in global state to track the primary instance
await context.globalState.update(this.GLOBAL_STATE_KEY, Date.now().toString())
}
Expand Down
153 changes: 153 additions & 0 deletions src/services/mcp/__tests__/McpHub.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ vi.mock("vscode", () => ({
Disposable: {
from: vi.fn(),
},
Uri: {
file: vi.fn((path) => ({
scheme: "file",
authority: "",
path,
query: "",
fragment: "",
fsPath: path,
with: vi.fn(),
toJSON: vi.fn(),
})),
},
RelativePattern: vi.fn((base, pattern) => ({ base, pattern })),
}))
vi.mock("fs/promises")
vi.mock("../../../core/webview/ClineProvider")
Expand Down Expand Up @@ -1226,4 +1239,144 @@ describe("McpHub", () => {
)
})
})

describe("Conditional initialization", () => {
it("should not initialize servers when autoInitialize is false", async () => {
// Clear existing mocks
vi.clearAllMocks()

// Create McpHub with autoInitialize = false
const mcpHubNoInit = new McpHub(mockProvider as ClineProvider, false)

// Verify no connections were created
expect(mcpHubNoInit.connections.length).toBe(0)

// Verify fs.readFile was not called for server initialization
expect(fs.readFile).not.toHaveBeenCalled()
})

it("should initialize servers when autoInitialize is true", async () => {
// Clear existing mocks
vi.clearAllMocks()

// Mock the config file read
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify({
mcpServers: {
"test-server": {
type: "stdio",
command: "node",
args: ["test.js"],
},
},
}),
)

// Create McpHub with autoInitialize = true
const mcpHubInit = new McpHub(mockProvider as ClineProvider, true)

// Wait a bit for async initialization
await new Promise((resolve) => setTimeout(resolve, 100))

// Verify fs.readFile was called for initialization
expect(fs.readFile).toHaveBeenCalled()
})

it("should enable servers after creation when enableServers is called", async () => {
// Clear existing mocks
vi.clearAllMocks()

// Mock the config file read
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify({
mcpServers: {
"test-server": {
type: "stdio",
command: "node",
args: ["test.js"],
},
},
}),
)

// Mock StdioClientTransport
const mockTransport = {
start: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
stderr: {
on: vi.fn(),
},
onerror: null,
onclose: null,
}

const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
StdioClientTransport.mockImplementation(() => mockTransport)

// Mock Client
const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
const Client = clientModule.Client as ReturnType<typeof vi.fn>
Client.mockImplementation(() => ({
connect: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
getInstructions: vi.fn().mockReturnValue("test instructions"),
request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
}))

// Create McpHub without auto-initialization
const mcpHubNoInit = new McpHub(mockProvider as ClineProvider, false)

// Verify no connections initially
expect(mcpHubNoInit.connections.length).toBe(0)

// Enable servers
await mcpHubNoInit.enableServers()

// Verify connections were created
expect(mcpHubNoInit.connections.length).toBeGreaterThan(0)
})

it("should disable all servers when disableServers is called", async () => {
// Create McpHub with some connections
mcpHub.connections = [
{
server: {
name: "test-server-1",
source: "global",
} as any,
client: {
close: vi.fn().mockResolvedValue(undefined),
} as any,
transport: {
close: vi.fn().mockResolvedValue(undefined),
} as any,
},
{
server: {
name: "test-server-2",
source: "project",
} as any,
client: {
close: vi.fn().mockResolvedValue(undefined),
} as any,
transport: {
close: vi.fn().mockResolvedValue(undefined),
} as any,
},
]

// Disable servers
await mcpHub.disableServers()

// Verify all connections were closed
expect(mcpHub.connections.length).toBe(0)
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
expect.objectContaining({
type: "mcpServers",
mcpServers: [],
}),
)
})
})
})