type MCPClient = Client

/**
 * Bookkeeping for a locally spawned (stdio) MCP server.
 * The transport is retained so its `pid` can be read later for process cleanup.
 */
type LocalTransportInfo = {
  transport: StdioClientTransport
}

/**
 * Reports whether a single process with the given PID currently exists.
 *
 * Uses signal 0, which performs the existence/permission check without
 * delivering a signal.
 *
 * @param pid PID to probe. Must be a positive integer; anything else yields `false`.
 * @returns `true` if the process exists (even if we may not signal it), else `false`.
 */
function isProcessAlive(pid: number): boolean {
  // pid 0 and negative pids address process groups rather than one process,
  // and non-integers are never valid PIDs, so none of them can be "alive".
  if (!Number.isInteger(pid) || pid <= 0) return false
  try {
    // Sending signal 0 checks if the process exists without actually killing it
    process.kill(pid, 0)
    return true
  } catch (error) {
    // EPERM means the process exists but belongs to someone we cannot signal.
    // Every other failure (ESRCH, invalid argument, ...) is treated as "not running".
    return (error as { code?: unknown } | null)?.code === "EPERM"
  }
}

// Timeout for kill operations to prevent hanging
const KILL_TIMEOUT_MS = 5000
// Grace period before escalating SIGTERM to SIGKILL (aligns with Shell.killTree)
const SIGKILL_DELAY_MS = 200

// Helper to forcefully kill a process by PID
// Handles both Windows (taskkill) and Unix (SIGTERM→SIGKILL) platforms
// Note: We can't reuse Shell.killTree here because it requires a ChildProcess object,
// but MCP SDK's StdioClientTransport only exposes the PID, not the underlying process.
/**
 * Best-effort termination of a local MCP server process known only by its PID.
 *
 * - Windows: runs `taskkill /pid <pid> /f /t`, bounded by `KILL_TIMEOUT_MS`.
 * - Unix: sends SIGTERM to the process group (falling back to the single
 *   process), waits up to `SIGKILL_DELAY_MS` for it to exit, then sends
 *   SIGKILL if it is still alive.
 *
 * Never throws; every failure is logged at debug level.
 *
 * @param pid  PID reported by the stdio transport.
 * @param name MCP server name, used only as log context.
 */
async function forceKillProcess(pid: number, name: string): Promise<void> {
  // Refuse PIDs that would hit more than the intended child: `kill(-1, …)`
  // signals every process we are allowed to signal, pid 0 / negative values
  // address process groups, and our own pid would take this process down.
  if (!Number.isInteger(pid) || pid <= 1 || pid === process.pid) {
    log.debug("refusing to force kill unsafe pid", { name, pid })
    return
  }

  // Check if process is still alive before attempting to kill
  if (!isProcessAlive(pid)) {
    log.debug("MCP process already terminated", { name, pid })
    return
  }

  log.info("force killing MCP server process", { name, pid })
  try {
    if (process.platform === "win32") {
      const { spawn } = await import("child_process")
      const killer = spawn("taskkill", ["/pid", String(pid), "/f", "/t"], { stdio: "ignore" })
      // Resolve on whichever comes first: exit, spawn error, or our own timeout.
      await new Promise<void>((resolve) => {
        const timeout = setTimeout(() => {
          log.debug("taskkill timed out, killing taskkill process", { name, pid })
          killer.kill("SIGKILL")
          resolve()
        }, KILL_TIMEOUT_MS)
        killer.once("exit", (code) => {
          clearTimeout(timeout)
          if (code !== 0) {
            log.debug("taskkill exited with non-zero code", { name, pid, code })
          }
          resolve()
        })
        killer.once("error", (error) => {
          clearTimeout(timeout)
          log.debug("taskkill process error", { name, pid, error })
          resolve()
        })
      })
    } else {
      // Unix: Try graceful SIGTERM first, then escalate to SIGKILL
      // This aligns with Shell.killTree behavior
      const tryKill = (signal: NodeJS.Signals) => {
        try {
          process.kill(-pid, signal) // Process group first
        } catch {
          try {
            process.kill(pid, signal) // Individual process fallback
          } catch {
            // Process already gone
          }
        }
      }

      tryKill("SIGTERM")

      // Poll instead of sleeping the whole grace period so a server that
      // exits promptly does not cost the full delay (cleanup of several
      // servers runs these kills one after another).
      const pollIntervalMs = 20
      const deadline = Date.now() + SIGKILL_DELAY_MS
      while (isProcessAlive(pid) && Date.now() < deadline) {
        await Bun.sleep(pollIntervalMs)
      }

      // If still alive, escalate to SIGKILL
      if (isProcessAlive(pid)) {
        log.debug("process still alive after SIGTERM, sending SIGKILL", { name, pid })
        tryKill("SIGKILL")
      }
    }
  } catch (error) {
    log.debug("failed to force kill MCP process", { name, pid, error })
  }
}
{} const clients: Record = {} const status: Record = {} + const localTransports: Record = {} await Promise.all( Object.entries(config).map(async ([key, mcp]) => { @@ -181,23 +267,38 @@ export namespace MCP { if (result.mcpClient) { clients[key] = result.mcpClient } + if (result.localTransport) { + localTransports[key] = result.localTransport + } }), ) return { status, clients, + localTransports, } }, async (state) => { + // First, try to gracefully close all clients with a timeout + const closeTimeout = 2000 await Promise.all( Object.values(state.clients).map((client) => - client.close().catch((error) => { - log.error("Failed to close MCP client", { - error, - }) + withTimeout(client.close(), closeTimeout).catch((error) => { + log.error("Failed to close MCP client", { error }) }), ), ) + + // Then forcefully kill any remaining local MCP server processes + // This handles processes that don't respond to the SDK's abort signal + // (e.g., docker containers without --init, hung processes) + for (const [name, info] of Object.entries(state.localTransports)) { + const pid = info.transport.pid + if (pid) { + await forceKillProcess(pid, name) + } + } + pendingOAuthTransports.clear() }, ) @@ -268,6 +369,9 @@ export namespace MCP { } s.clients[name] = result.mcpClient s.status[name] = result.status + if (result.localTransport) { + s.localTransports[name] = result.localTransport + } return { status: s.status, @@ -279,12 +383,14 @@ export namespace MCP { log.info("mcp server disabled", { key }) return { mcpClient: undefined, + localTransport: undefined, status: { status: "disabled" as const }, } } log.info("found", { key, type: mcp.type }) let mcpClient: MCPClient | undefined + let localTransport: LocalTransportInfo | undefined let status: Status | undefined = undefined if (mcp.type === "remote") { @@ -406,6 +512,9 @@ export namespace MCP { }, }) + // Track transport before connection so we can clean up on failure + localTransport = { transport } + const connectTimeout = 
mcp.timeout ?? DEFAULT_TIMEOUT try { const client = new Client({ @@ -425,6 +534,12 @@ export namespace MCP { cwd, error: error instanceof Error ? error.message : String(error), }) + // Force kill the process if connection failed + const pid = transport.pid + if (pid) { + await forceKillProcess(pid, key) + } + localTransport = undefined status = { status: "failed" as const, error: error instanceof Error ? error.message : String(error), @@ -442,6 +557,7 @@ export namespace MCP { if (!mcpClient) { return { mcpClient: undefined, + localTransport: undefined, status, } } @@ -456,12 +572,20 @@ export namespace MCP { error, }) }) + // Force kill the process if listTools failed (graceful close may not work) + if (localTransport) { + const pid = localTransport.transport.pid + if (pid) { + await forceKillProcess(pid, key) + } + } status = { status: "failed", error: "Failed to get tools", } return { mcpClient: undefined, + localTransport: undefined, status: { status: "failed" as const, error: "Failed to get tools", @@ -472,6 +596,7 @@ export namespace MCP { log.info("create() successfully created client", { key, toolCount: result.tools.length }) return { mcpClient, + localTransport, status, } } @@ -525,17 +650,33 @@ export namespace MCP { if (result.mcpClient) { s.clients[name] = result.mcpClient } + if (result.localTransport) { + s.localTransports[name] = result.localTransport + } } export async function disconnect(name: string) { const s = await state() const client = s.clients[name] + const transportInfo = s.localTransports[name] + if (client) { - await client.close().catch((error) => { + // Try graceful close first + await withTimeout(client.close(), 2000).catch((error) => { log.error("Failed to close MCP client", { name, error }) }) delete s.clients[name] } + + // Force kill the process if it's a local MCP server + if (transportInfo) { + const pid = transportInfo.transport.pid + if (pid) { + await forceKillProcess(pid, name) + } + delete s.localTransports[name] + } + 
import { test, expect, beforeEach, mock } from "bun:test"

// Mock PIDs start far above anything a real OS hands out (Linux pid_max tops
// out at 4194304). MCP.disconnect and instance disposal pass the transport's
// pid to the force-kill helper, so a realistic value such as 12345 could make
// this suite signal an unrelated process on the machine running the tests.
// The value still fits in a signed 32-bit integer, as process.kill requires.
const MOCK_PID_BASE = 2000000000

// Track spawned transports for verification
let mockPid = MOCK_PID_BASE
const mockTransports: Array<{
  pid: number | null
}> = []

// Track client.close() calls
const clientCloseCallCount = { value: 0 }

// Stand-in for the SDK stdio transport: never spawns anything, only reports a fake pid.
mock.module("@modelcontextprotocol/sdk/client/stdio.js", () => ({
  StdioClientTransport: class MockStdioTransport {
    private _pid: number | null = mockPid++

    constructor(_options: unknown) {
      mockTransports.push({ pid: this._pid })
    }

    get pid() {
      return this._pid
    }

    async start() {}
    async close() {}

    onmessage?: (message: unknown) => void
    onerror?: (error: Error) => void
    onclose?: () => void

    send(_message: unknown) {
      return Promise.resolve()
    }
  },
}))

// Stand-in for the SDK client: connects instantly, exposes no tools, counts close() calls.
mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({
  Client: class MockClient {
    async connect(_transport: unknown) {}
    async close() {
      clientCloseCallCount.value++
    }
    async listTools() {
      return { tools: [] }
    }
    setNotificationHandler() {}
  },
}))

beforeEach(() => {
  mockTransports.length = 0
  mockPid = MOCK_PID_BASE
  clientCloseCallCount.value = 0
})

// Import after mocking
const { MCP } = await import("../../src/mcp/index")
const { Instance } = await import("../../src/project/instance")
const { tmpdir } = await import("../fixture/fixture")

test("local MCP server transport is created with PID tracking", async () => {
  await using tmp = await tmpdir()

  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      await MCP.add("test-server", {
        type: "local",
        command: ["echo", "test"],
      })

      // Transport should be created with a PID
      expect(mockTransports.length).toBe(1)
      expect(mockTransports[0].pid).toBe(MOCK_PID_BASE)
    },
  })
})

test("multiple local servers each get unique PID tracking", async () => {
  await using tmp = await tmpdir()

  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      await MCP.add("server-1", { type: "local", command: ["cmd1"] })
      await MCP.add("server-2", { type: "local", command: ["cmd2"] })
      await MCP.add("server-3", { type: "local", command: ["cmd3"] })

      expect(mockTransports.length).toBe(3)

      // Each has unique PID
      const pids = mockTransports.map((t) => t.pid)
      expect(new Set(pids).size).toBe(3)
    },
  })
})

test("MCP.disconnect calls client.close()", async () => {
  await using tmp = await tmpdir()

  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      await MCP.add("test-server", {
        type: "local",
        command: ["echo", "test"],
      })

      expect(clientCloseCallCount.value).toBe(0)

      await MCP.disconnect("test-server")

      // client.close() should have been called
      expect(clientCloseCallCount.value).toBe(1)
    },
  })
})

test("MCP.connect also creates transport with PID tracking", async () => {
  await using tmp = await tmpdir({
    init: async (dir) => {
      await Bun.write(
        `${dir}/opencode.json`,
        JSON.stringify({
          $schema: "https://opencode.ai/config.json",
          mcp: {
            "configured-server": {
              type: "local",
              command: ["echo", "test"],
              enabled: false,
            },
          },
        }),
      )
    },
  })

  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      // Initially no transports (server is disabled)
      expect(mockTransports.length).toBe(0)

      // Connect creates transport
      await MCP.connect("configured-server")

      expect(mockTransports.length).toBe(1)
      expect(mockTransports[0].pid).toBe(MOCK_PID_BASE)
    },
  })
})