Skip to content

Commit 90cf89d

Browse files
committed
feat: implement discriminated union types for MCP connections and enhance connection handling
1 parent 78b784b commit 90cf89d

File tree

2 files changed

+736
-69
lines changed

2 files changed

+736
-69
lines changed

src/services/mcp/McpHub.ts

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,27 @@ import { fileExistsAtPath } from "../../utils/fs"
3333
import { arePathsEqual } from "../../utils/path"
3434
import { injectVariables } from "../../utils/config"
3535

36-
export type McpConnection = {
36+
// Discriminated union for connection states
37+
export type ConnectedMcpConnection = {
38+
type: "connected"
3739
server: McpServer
38-
client: Client | null
39-
transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport | null
40+
client: Client
41+
transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport
42+
}
43+
44+
export type DisconnectedMcpConnection = {
45+
type: "disconnected"
46+
server: McpServer
47+
client: null
48+
transport: null
49+
}
50+
51+
export type McpConnection = ConnectedMcpConnection | DisconnectedMcpConnection
52+
53+
// Enum for disable reasons
54+
export enum DisableReason {
55+
MCP_DISABLED = "mcpDisabled",
56+
SERVER_DISABLED = "serverDisabled",
4057
}
4158

4259
// Base configuration schema for common settings
@@ -559,20 +576,21 @@ export class McpHub {
559576
* @param config The server configuration
560577
* @param source The source of the server (global or project)
561578
* @param reason The reason for creating a placeholder (mcpDisabled or serverDisabled)
562-
* @returns A placeholder McpConnection object
579+
* @returns A placeholder DisconnectedMcpConnection object
563580
*/
564581
private createPlaceholderConnection(
565582
name: string,
566583
config: z.infer<typeof ServerConfigSchema>,
567584
source: "global" | "project",
568-
reason: "mcpDisabled" | "serverDisabled",
569-
): McpConnection {
585+
reason: DisableReason,
586+
): DisconnectedMcpConnection {
570587
return {
588+
type: "disconnected",
571589
server: {
572590
name,
573591
config: JSON.stringify(config),
574592
status: "disconnected",
575-
disabled: reason === "serverDisabled" ? true : config.disabled,
593+
disabled: reason === DisableReason.SERVER_DISABLED ? true : config.disabled,
576594
source,
577595
projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined,
578596
errorHistory: [],
@@ -607,19 +625,22 @@ export class McpHub {
607625
const mcpEnabled = await this.isMcpEnabled()
608626
if (!mcpEnabled) {
609627
// Still create a connection object to track the server, but don't actually connect
610-
const connection = this.createPlaceholderConnection(name, config, source, "mcpDisabled")
628+
const connection = this.createPlaceholderConnection(name, config, source, DisableReason.MCP_DISABLED)
611629
this.connections.push(connection)
612630
return
613631
}
614632

615633
// Skip connecting to disabled servers
616634
if (config.disabled) {
617635
// Still create a connection object to track the server, but don't actually connect
618-
const connection = this.createPlaceholderConnection(name, config, source, "serverDisabled")
636+
const connection = this.createPlaceholderConnection(name, config, source, DisableReason.SERVER_DISABLED)
619637
this.connections.push(connection)
620638
return
621639
}
622640

641+
// Set up file watchers for enabled servers
642+
this.setupFileWatcher(name, config, source)
643+
623644
try {
624645
const client = new Client(
625646
{
@@ -793,7 +814,9 @@ export class McpHub {
793814
transport.start = async () => {}
794815
}
795816

796-
const connection: McpConnection = {
817+
// Create a connected connection
818+
const connection: ConnectedMcpConnection = {
819+
type: "connected",
797820
server: {
798821
name,
799822
config: JSON.stringify(configInjected),
@@ -886,8 +909,8 @@ export class McpHub {
886909
// Use the helper method to find the connection
887910
const connection = this.findConnection(serverName, source)
888911

889-
if (!connection || !connection.client) {
890-
throw new Error(`Server ${serverName} not found or not connected`)
912+
if (!connection || connection.type !== "connected") {
913+
return []
891914
}
892915

893916
const response = await connection.client.request({ method: "tools/list" }, ListToolsResultSchema)
@@ -941,7 +964,7 @@ export class McpHub {
941964
private async fetchResourcesList(serverName: string, source?: "global" | "project"): Promise<McpResource[]> {
942965
try {
943966
const connection = this.findConnection(serverName, source)
944-
if (!connection || !connection.client) {
967+
if (!connection || connection.type !== "connected") {
945968
return []
946969
}
947970
const response = await connection.client.request({ method: "resources/list" }, ListResourcesResultSchema)
@@ -958,7 +981,7 @@ export class McpHub {
958981
): Promise<McpResourceTemplate[]> {
959982
try {
960983
const connection = this.findConnection(serverName, source)
961-
if (!connection || !connection.client) {
984+
if (!connection || connection.type !== "connected") {
962985
return []
963986
}
964987
const response = await connection.client.request(
@@ -973,17 +996,18 @@ export class McpHub {
973996
}
974997

975998
async deleteConnection(name: string, source?: "global" | "project"): Promise<void> {
999+
// Clean up file watchers for this server
1000+
this.removeFileWatchersForServer(name)
1001+
9761002
// If source is provided, only delete connections from that source
9771003
const connections = source
9781004
? this.connections.filter((conn) => conn.server.name === name && conn.server.source === source)
9791005
: this.connections.filter((conn) => conn.server.name === name)
9801006

9811007
for (const connection of connections) {
9821008
try {
983-
if (connection.transport) {
1009+
if (connection.type === "connected") {
9841010
await connection.transport.close()
985-
}
986-
if (connection.client) {
9871011
await connection.client.close()
9881012
}
9891013
} catch (error) {
@@ -1136,6 +1160,14 @@ export class McpHub {
11361160
this.fileWatchers.clear()
11371161
}
11381162

1163+
private removeFileWatchersForServer(serverName: string) {
1164+
const watchers = this.fileWatchers.get(serverName)
1165+
if (watchers) {
1166+
watchers.forEach((watcher) => watcher.close())
1167+
this.fileWatchers.delete(serverName)
1168+
}
1169+
}
1170+
11391171
async restartConnection(serverName: string, source?: "global" | "project"): Promise<void> {
11401172
this.isConnecting = true
11411173

@@ -1349,13 +1381,16 @@ export class McpHub {
13491381

13501382
// If disabling a connected server, disconnect it
13511383
if (disabled && connection.server.status === "connected") {
1384+
// Clean up file watchers when disabling
1385+
this.removeFileWatchersForServer(serverName)
13521386
await this.deleteConnection(serverName, serverSource)
13531387
// Re-add as a disabled connection
13541388
await this.connectToServer(serverName, JSON.parse(connection.server.config), serverSource)
13551389
} else if (!disabled && connection.server.status === "disconnected") {
13561390
// If enabling a disabled server, connect it
13571391
const config = JSON.parse(connection.server.config)
13581392
await this.deleteConnection(serverName, serverSource)
1393+
// When re-enabling, file watchers will be set up in connectToServer
13591394
await this.connectToServer(serverName, config, serverSource)
13601395
} else if (connection.server.status === "connected") {
13611396
// Only refresh capabilities if connected
@@ -1539,7 +1574,7 @@ export class McpHub {
15391574

15401575
async readResource(serverName: string, uri: string, source?: "global" | "project"): Promise<McpResourceResponse> {
15411576
const connection = this.findConnection(serverName, source)
1542-
if (!connection || !connection.client) {
1577+
if (!connection || connection.type !== "connected") {
15431578
throw new Error(`No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}`)
15441579
}
15451580
if (connection.server.disabled) {
@@ -1563,7 +1598,7 @@ export class McpHub {
15631598
source?: "global" | "project",
15641599
): Promise<McpToolCallResponse> {
15651600
const connection = this.findConnection(serverName, source)
1566-
if (!connection || !connection.client) {
1601+
if (!connection || connection.type !== "connected") {
15671602
throw new Error(
15681603
`No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
15691604
)

0 commit comments

Comments
 (0)