Skip to content

Commit 8c4b642

Browse files
McpHub.ts - Better Zod Types and Function Refactor (RooCodeInc#2497)
* refactor types and functions * fix timeout validation type error * changeset
1 parent 0204396 commit 8c4b642

File tree

2 files changed

+88
-77
lines changed

2 files changed

+88
-77
lines changed

.changeset/proud-games-do.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"claude-dev": patch
3+
---
4+
5+
Refactor types and functions in McpHub

src/services/mcp/McpHub.ts

Lines changed: 83 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -30,61 +30,43 @@ import { fileExistsAtPath } from "../../utils/fs"
3030
import { arePathsEqual } from "../../utils/path"
3131
import { secondsToMs } from "../../utils/time"
3232
import { GlobalFileNames } from "../../global-constants"
33+
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
34+
3335
export type McpConnection = {
3436
server: McpServer
3537
client: Client
36-
transport: StdioClientTransport
38+
transport: StdioClientTransport | SSEClientTransport
3739
}
3840

3941
export type McpTransportType = "stdio" | "sse"
4042

41-
export type McpServerConfig = {
42-
transportType: McpTransportType
43-
autoApprove?: string[]
44-
disabled?: boolean
45-
timeout?: number
46-
} & (
47-
| {
48-
// Stdio specific
49-
transportType: "stdio"
50-
command: string
51-
args?: string[]
52-
env?: Record<string, string>
53-
}
54-
| {
55-
// SSE specific
56-
transportType: "sse"
57-
url: string
58-
headers?: Record<string, string>
59-
withCredentials?: boolean
60-
}
61-
)
43+
export type McpServerConfig = z.infer<typeof ServerConfigSchema>
6244

6345
const AutoApproveSchema = z.array(z.string()).default([])
6446

65-
const SseConfigSchema = z.object({
66-
transportType: z.literal("sse"),
67-
url: z.string().url(),
68-
headers: z.record(z.string()).optional(),
69-
withCredentials: z.boolean().optional().default(false),
47+
const BaseConfigSchema = z.object({
7048
autoApprove: AutoApproveSchema.optional(),
7149
disabled: z.boolean().optional(),
7250
timeout: z.number().min(MIN_MCP_TIMEOUT_SECONDS).optional().default(DEFAULT_MCP_TIMEOUT_SECONDS),
7351
})
7452

75-
const StdioConfigSchema = z.object({
53+
const SseConfigSchema = BaseConfigSchema.extend({
54+
url: z.string().url(),
55+
}).transform((config) => ({
56+
...config,
57+
transportType: "sse" as const,
58+
}))
59+
60+
const StdioConfigSchema = BaseConfigSchema.extend({
7661
command: z.string(),
7762
args: z.array(z.string()).optional(),
7863
env: z.record(z.string()).optional(),
79-
autoApprove: AutoApproveSchema.optional(),
80-
disabled: z.boolean().optional(),
81-
timeout: z.number().min(MIN_MCP_TIMEOUT_SECONDS).optional().default(DEFAULT_MCP_TIMEOUT_SECONDS),
82-
})
64+
}).transform((config) => ({
65+
...config,
66+
transportType: "stdio" as const,
67+
}))
8368

84-
const ServerConfigSchema = z.discriminatedUnion("transportType", [
85-
StdioConfigSchema.extend({ transportType: z.literal("stdio") }),
86-
SseConfigSchema,
87-
])
69+
const ServerConfigSchema = z.union([StdioConfigSchema, SseConfigSchema])
8870

8971
const McpSettingsSchema = z.object({
9072
mcpServers: z.record(ServerConfigSchema),
@@ -142,50 +124,68 @@ export class McpHub {
142124
return mcpSettingsFilePath
143125
}
144126

127+
private async readAndValidateMcpSettingsFile(): Promise<z.infer<typeof McpSettingsSchema> | undefined> {
128+
try {
129+
const settingsPath = await this.getMcpSettingsFilePath()
130+
const content = await fs.readFile(settingsPath, "utf-8")
131+
132+
let config: any
133+
134+
// Parse JSON file content
135+
try {
136+
config = JSON.parse(content)
137+
} catch (error) {
138+
vscode.window.showErrorMessage(
139+
"Invalid MCP settings format. Please ensure your settings follow the correct JSON format.",
140+
)
141+
return undefined
142+
}
143+
144+
// Validate against schema
145+
const result = McpSettingsSchema.safeParse(config)
146+
if (!result.success) {
147+
vscode.window.showErrorMessage("Invalid MCP settings schema.")
148+
return undefined
149+
}
150+
151+
return result.data
152+
} catch (error) {
153+
console.error("Failed to read MCP settings:", error)
154+
return undefined
155+
}
156+
}
157+
145158
private async watchMcpSettingsFile(): Promise<void> {
146159
const settingsPath = await this.getMcpSettingsFilePath()
147160
this.disposables.push(
148161
vscode.workspace.onDidSaveTextDocument(async (document) => {
149162
if (arePathsEqual(document.uri.fsPath, settingsPath)) {
150-
const content = await fs.readFile(settingsPath, "utf-8")
151-
const errorMessage =
152-
"Invalid MCP settings format. Please ensure your settings follow the correct JSON format."
153-
let config: any
154-
try {
155-
config = JSON.parse(content)
156-
} catch (error) {
157-
vscode.window.showErrorMessage(errorMessage)
158-
return
159-
}
160-
const result = McpSettingsSchema.safeParse(config)
161-
if (!result.success) {
162-
vscode.window.showErrorMessage(errorMessage)
163-
return
164-
}
165-
try {
166-
vscode.window.showInformationMessage("Updating MCP servers...")
167-
await this.updateServerConnections(result.data.mcpServers || {})
168-
vscode.window.showInformationMessage("MCP servers updated")
169-
} catch (error) {
170-
console.error("Failed to process MCP settings change:", error)
163+
const settings = await this.readAndValidateMcpSettingsFile()
164+
if (settings) {
165+
try {
166+
vscode.window.showInformationMessage("Updating MCP servers...")
167+
await this.updateServerConnections(settings.mcpServers)
168+
vscode.window.showInformationMessage("MCP servers updated")
169+
} catch (error) {
170+
console.error("Failed to process MCP settings change:", error)
171+
}
171172
}
172173
}
173174
}),
174175
)
175176
}
176177

177178
private async initializeMcpServers(): Promise<void> {
178-
try {
179-
const settingsPath = await this.getMcpSettingsFilePath()
180-
const content = await fs.readFile(settingsPath, "utf-8")
181-
const config = JSON.parse(content)
182-
await this.updateServerConnections(config.mcpServers || {})
183-
} catch (error) {
184-
console.error("Failed to initialize MCP servers:", error)
179+
const settings = await this.readAndValidateMcpSettingsFile()
180+
if (settings) {
181+
await this.updateServerConnections(settings.mcpServers)
185182
}
186183
}
187184

188-
private async connectToServer(name: string, config: StdioServerParameters): Promise<void> {
185+
private async connectToServer(
186+
name: string,
187+
config: z.infer<typeof StdioConfigSchema> | z.infer<typeof SseConfigSchema>,
188+
): Promise<void> {
189189
// Remove existing connection if it exists (should never happen, the connection should be deleted beforehand)
190190
this.connections = this.connections.filter((conn) => conn.server.name !== name)
191191

@@ -201,16 +201,22 @@ export class McpHub {
201201
},
202202
)
203203

204-
const transport = new StdioClientTransport({
205-
command: config.command,
206-
args: config.args,
207-
env: {
208-
...config.env,
209-
...(process.env.PATH ? { PATH: process.env.PATH } : {}),
210-
// ...(process.env.NODE_PATH ? { NODE_PATH: process.env.NODE_PATH } : {}),
211-
},
212-
stderr: "pipe", // necessary for stderr to be available
213-
})
204+
let transport: StdioClientTransport | SSEClientTransport
205+
206+
if (config.transportType === "sse") {
207+
return
208+
} else {
209+
transport = new StdioClientTransport({
210+
command: config.command,
211+
args: config.args,
212+
env: {
213+
...config.env,
214+
...(process.env.PATH ? { PATH: process.env.PATH } : {}),
215+
// ...(process.env.NODE_PATH ? { NODE_PATH: process.env.NODE_PATH } : {}),
216+
},
217+
stderr: "pipe", // necessary for stderr to be available
218+
})
219+
}
214220

215221
transport.onerror = async (error) => {
216222
console.error(`Transport error for "${name}":`, error)
@@ -372,7 +378,7 @@ export class McpHub {
372378
}
373379
}
374380

375-
async updateServerConnections(newServers: Record<string, any>): Promise<void> {
381+
async updateServerConnections(newServers: Record<string, McpServerConfig>): Promise<void> {
376382
this.isConnecting = true
377383
this.removeAllFileWatchers()
378384
const currentNames = new Set(this.connections.map((conn) => conn.server.name))
@@ -695,7 +701,7 @@ export class McpHub {
695701
public async updateServerTimeout(serverName: string, timeout: number): Promise<void> {
696702
try {
697703
// Validate timeout against schema
698-
const setConfigResult = StdioConfigSchema.shape.timeout.safeParse(timeout)
704+
const setConfigResult = BaseConfigSchema.shape.timeout.safeParse(timeout)
699705
if (!setConfigResult.success) {
700706
throw new Error(`Invalid timeout value: ${timeout}. Must be at minimum ${MIN_MCP_TIMEOUT_SECONDS} seconds.`)
701707
}

0 commit comments

Comments
 (0)