Skip to content

Commit e8f99e2

Browse files
committed
Add tests for toggleAllServersDisabled and restartAllMcpServers methods in McpHub
1 parent a7a96fb commit e8f99e2

File tree

1 file changed

+283
-4
lines changed

1 file changed

+283
-4
lines changed

src/services/mcp/__tests__/McpHub.test.ts

Lines changed: 283 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import type { ExtensionContext, Uri } from "vscode"
44
import { ServerConfigSchema } from "../McpHub"
55

66
const fs = require("fs/promises")
7-
const { McpHub } = require("../McpHub")
7+
const { McpHub } = jest.requireActual("../McpHub") // Use requireActual to get the real module
8+
9+
let originalConsoleError: typeof console.error = console.error // Store original console methods globally
810

911
jest.mock("vscode", () => ({
1012
workspace: {
@@ -30,17 +32,39 @@ jest.mock("vscode", () => ({
3032
jest.mock("fs/promises")
3133
jest.mock("../../../core/webview/ClineProvider")
3234

35+
// Mock the McpHub module itself
36+
jest.mock("../McpHub", () => {
37+
const originalModule = jest.requireActual("../McpHub")
38+
return {
39+
__esModule: true,
40+
...originalModule,
41+
McpHub: jest.fn().mockImplementation((provider) => {
42+
const instance = new originalModule.McpHub(provider)
43+
// Spy on private methods
44+
jest.spyOn(instance, "updateServerConfig" as any).mockResolvedValue(undefined)
45+
jest.spyOn(instance, "findConnection" as any).mockReturnValue({ server: { disabled: false } } as any)
46+
jest.spyOn(instance, "initializeMcpServers" as any).mockResolvedValue(undefined)
47+
jest.spyOn(instance, "notifyWebviewOfServerChanges" as any).mockResolvedValue(undefined)
48+
jest.spyOn(instance, "restartConnection" as any).mockResolvedValue(undefined)
49+
jest.spyOn(instance, "showErrorMessage" as any).mockImplementation(jest.fn())
50+
jest.spyOn(instance, "getAllServers" as any).mockReturnValue([
51+
{ name: "server1", source: "global", disabled: false, config: "{}", status: "connected" },
52+
{ name: "server2", source: "project", disabled: false, config: "{}", status: "connected" },
53+
])
54+
return instance
55+
}),
56+
}
57+
})
58+
3359
describe("McpHub", () => {
3460
let mcpHub: McpHubType
3561
let mockProvider: Partial<ClineProvider>
3662

37-
// Store original console methods
38-
const originalConsoleError = console.error
39-
4063
beforeEach(() => {
4164
jest.clearAllMocks()
4265

4366
// Mock console.error to suppress error messages during tests
67+
originalConsoleError = console.error // Store original before mocking
4468
console.error = jest.fn()
4569

4670
const mockUri: Uri = {
@@ -317,6 +341,136 @@ describe("McpHub", () => {
317341
})
318342
})
319343

344+
describe("toggleAllServersDisabled", () => {
345+
it("should disable all servers when passed true", async () => {
346+
const mockConnections: McpConnection[] = [
347+
{
348+
server: {
349+
name: "server1",
350+
config: "{}",
351+
status: "connected",
352+
disabled: false,
353+
},
354+
client: {} as any,
355+
transport: {} as any,
356+
},
357+
{
358+
server: {
359+
name: "server2",
360+
config: "{}",
361+
status: "connected",
362+
disabled: false,
363+
},
364+
client: {} as any,
365+
transport: {} as any,
366+
},
367+
]
368+
mcpHub.connections = mockConnections
369+
370+
// Mock fs.readFile to return a config with both servers enabled
371+
;(fs.readFile as jest.Mock).mockResolvedValueOnce(
372+
JSON.stringify({
373+
mcpServers: {
374+
server1: { disabled: false },
375+
server2: { disabled: false },
376+
},
377+
}),
378+
)
379+
380+
await mcpHub.toggleAllServersDisabled(true)
381+
382+
// Verify that both servers are now disabled in the connections
383+
expect(mcpHub.connections[0].server.disabled).toBe(true)
384+
expect(mcpHub.connections[1].server.disabled).toBe(true)
385+
386+
// Mock fs.readFile and fs.writeFile to track config changes
387+
let currentConfig = JSON.stringify({
388+
mcpServers: {
389+
server1: { disabled: false },
390+
server2: { disabled: false },
391+
},
392+
})
393+
;(fs.readFile as jest.Mock).mockImplementation(async () => currentConfig)
394+
;(fs.writeFile as jest.Mock).mockImplementation(async (path, data) => {
395+
currentConfig = data
396+
})
397+
398+
await mcpHub.toggleAllServersDisabled(true)
399+
400+
// Verify that both servers are now disabled in the connections
401+
expect(mcpHub.connections[0].server.disabled).toBe(true)
402+
expect(mcpHub.connections[1].server.disabled).toBe(true)
403+
404+
// Verify that fs.writeFile was called to persist the changes
405+
const writtenConfig = JSON.parse(currentConfig)
406+
expect(writtenConfig.mcpServers.server1.disabled).toBe(true)
407+
expect(writtenConfig.mcpServers.server2.disabled).toBe(true)
408+
409+
// Verify that postMessageToWebview was called
410+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
411+
expect.objectContaining({
412+
type: "mcpServers",
413+
}),
414+
)
415+
})
416+
417+
it("should enable all servers when passed false", async () => {
418+
const mockConnections: McpConnection[] = [
419+
{
420+
server: {
421+
name: "server1",
422+
config: "{}",
423+
status: "connected",
424+
disabled: true,
425+
},
426+
client: {} as any,
427+
transport: {} as any,
428+
},
429+
{
430+
server: {
431+
name: "server2",
432+
config: "{}",
433+
status: "connected",
434+
disabled: true,
435+
},
436+
client: {} as any,
437+
transport: {} as any,
438+
},
439+
]
440+
mcpHub.connections = mockConnections
441+
442+
// Mock fs.readFile to return a config with both servers disabled
443+
let currentConfig = JSON.stringify({
444+
mcpServers: {
445+
server1: { disabled: true },
446+
server2: { disabled: true },
447+
},
448+
})
449+
;(fs.readFile as jest.Mock).mockImplementation(async () => currentConfig)
450+
;(fs.writeFile as jest.Mock).mockImplementation(async (path, data) => {
451+
currentConfig = data
452+
})
453+
454+
await mcpHub.toggleAllServersDisabled(false)
455+
456+
// Verify that both servers are now enabled in the connections
457+
expect(mcpHub.connections[0].server.disabled).toBe(false)
458+
expect(mcpHub.connections[1].server.disabled).toBe(false)
459+
460+
// Verify that fs.writeFile was called to persist the changes
461+
const writtenConfig = JSON.parse(currentConfig)
462+
expect(writtenConfig.mcpServers.server1.disabled).toBe(false)
463+
expect(writtenConfig.mcpServers.server2.disabled).toBe(false)
464+
465+
// Verify that postMessageToWebview was called
466+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
467+
expect.objectContaining({
468+
type: "mcpServers",
469+
}),
470+
)
471+
})
472+
})
473+
320474
describe("callTool", () => {
321475
it("should execute tool successfully", async () => {
322476
// Mock the connection with a minimal client implementation
@@ -560,4 +714,129 @@ describe("McpHub", () => {
560714
})
561715
})
562716
})
717+
718+
describe("restartAllMcpServers", () => {
719+
let mcpHub: McpHubType
720+
let mockProvider: Partial<ClineProvider>
721+
722+
beforeEach(() => {
723+
jest.clearAllMocks()
724+
// Mock console.error to suppress error messages during tests
725+
originalConsoleError = console.error // Store original before mocking
726+
console.error = jest.fn()
727+
728+
const mockUri: Uri = {
729+
scheme: "file",
730+
authority: "",
731+
path: "/test/path",
732+
query: "",
733+
fragment: "",
734+
fsPath: "/test/path",
735+
with: jest.fn(),
736+
toJSON: jest.fn(),
737+
}
738+
739+
mockProvider = {
740+
ensureSettingsDirectoryExists: jest.fn().mockResolvedValue("/mock/settings/path"),
741+
ensureMcpServersDirectoryExists: jest.fn().mockResolvedValue("/mock/settings/path"),
742+
postMessageToWebview: jest.fn(),
743+
context: {
744+
subscriptions: [],
745+
workspaceState: {} as any,
746+
globalState: {} as any,
747+
secrets: {} as any,
748+
extensionUri: mockUri,
749+
extensionPath: "/test/path",
750+
storagePath: "/test/storage",
751+
globalStoragePath: "/test/global-storage",
752+
environmentVariableCollection: {} as any,
753+
extension: {
754+
id: "test-extension",
755+
extensionUri: mockUri,
756+
extensionPath: "/test/path",
757+
extensionKind: 1,
758+
isActive: true,
759+
packageJSON: {
760+
version: "1.0.0",
761+
},
762+
activate: jest.fn(),
763+
exports: undefined,
764+
} as any,
765+
asAbsolutePath: (path: string) => path,
766+
storageUri: mockUri,
767+
globalStorageUri: mockUri,
768+
logUri: mockUri,
769+
extensionMode: 1,
770+
logPath: "/test/path",
771+
languageModelAccessInformation: {} as any,
772+
} as ExtensionContext,
773+
}
774+
775+
// Mock fs.readFile for initial settings
776+
;(fs.readFile as jest.Mock).mockResolvedValue(
777+
JSON.stringify({
778+
mcpServers: {
779+
"test-server": {
780+
type: "stdio",
781+
command: "node",
782+
args: ["test.js"],
783+
alwaysAllow: ["allowed-tool"],
784+
},
785+
},
786+
}),
787+
)
788+
789+
mcpHub = new McpHub(mockProvider as ClineProvider)
790+
jest.spyOn(mcpHub as any, "showErrorMessage").mockImplementation(jest.fn())
791+
792+
// Mock internal methods
793+
jest.spyOn(mcpHub, "getAllServers" as any).mockReturnValue([
794+
{ name: "server1", source: "global", disabled: false },
795+
{ name: "server2", source: "project", disabled: true }, // Disabled server
796+
{ name: "server3", source: "global", disabled: false },
797+
])
798+
jest.spyOn(mcpHub, "restartConnection" as any).mockResolvedValue(undefined)
799+
jest.spyOn(mcpHub as any, "notifyWebviewOfServerChanges").mockResolvedValue(undefined)
800+
})
801+
802+
afterEach(() => {
803+
// Restore original console methods
804+
console.error = originalConsoleError
805+
jest.restoreAllMocks() // Clean up spies
806+
})
807+
808+
it("should restart only active servers", async () => {
809+
await mcpHub.restartAllMcpServers()
810+
811+
expect(mcpHub.getAllServers).toHaveBeenCalled()
812+
expect(mcpHub.restartConnection).toHaveBeenCalledTimes(2) // Only server1 and server3 should be restarted
813+
expect(mcpHub.restartConnection).toHaveBeenCalledWith("server1", "global")
814+
expect(mcpHub.restartConnection).not.toHaveBeenCalledWith("server2", "project")
815+
expect(mcpHub.restartConnection).toHaveBeenCalledWith("server3", "global")
816+
expect((mcpHub as any).notifyWebviewOfServerChanges).toHaveBeenCalledTimes(1)
817+
})
818+
819+
it("should call showErrorMessage if a restart fails", async () => {
820+
jest.spyOn(mcpHub, "restartConnection" as any).mockRejectedValueOnce(new Error("Restart failed"))
821+
822+
await mcpHub.restartAllMcpServers()
823+
824+
expect(mcpHub.getAllServers).toHaveBeenCalled()
825+
expect(mcpHub.restartConnection).toHaveBeenCalledTimes(2) // Only active servers are attempted to restart
826+
expect((mcpHub as any).showErrorMessage).toHaveBeenCalledTimes(1)
827+
expect((mcpHub as any).showErrorMessage).toHaveBeenCalledWith(
828+
"Failed to restart MCP server server1",
829+
expect.any(Error),
830+
)
831+
expect((mcpHub as any).notifyWebviewOfServerChanges).toHaveBeenCalledTimes(1)
832+
})
833+
834+
it("should call notifyWebviewOfServerChanges even if some restarts fail", async () => {
835+
jest.spyOn(mcpHub, "restartConnection").mockRejectedValue(new Error("Restart failed"))
836+
837+
await mcpHub.restartAllMcpServers()
838+
839+
expect((mcpHub as any).notifyWebviewOfServerChanges).toHaveBeenCalledTimes(1)
840+
})
841+
})
563842
})

0 commit comments

Comments
 (0)