diff --git a/src/inMemory.test.ts b/src/inMemory.test.ts index baf43446..d62d3d85 100644 --- a/src/inMemory.test.ts +++ b/src/inMemory.test.ts @@ -118,4 +118,59 @@ describe("InMemoryTransport", () => { await serverTransport.start(); expect(receivedMessage).toEqual(message); }); + + test("should handle double close idempotently", async () => { + let clientCloseCount = 0; + let serverCloseCount = 0; + + clientTransport.onclose = () => { + clientCloseCount++; + }; + + serverTransport.onclose = () => { + serverCloseCount++; + }; + + await clientTransport.close(); + await clientTransport.close(); // Second close should be idempotent + + expect(clientCloseCount).toBe(1); + expect(serverCloseCount).toBe(1); + }); + + test("should handle concurrent close from both sides", async () => { + let clientCloseCount = 0; + let serverCloseCount = 0; + + clientTransport.onclose = () => { + clientCloseCount++; + }; + + serverTransport.onclose = () => { + serverCloseCount++; + }; + + // Close both sides concurrently + await Promise.all([ + clientTransport.close(), + serverTransport.close() + ]); + + expect(clientCloseCount).toBe(1); + expect(serverCloseCount).toBe(1); + }); + + test("should reject send after close from either side", async () => { + await serverTransport.close(); + + // Both sides should reject sends + await expect( + clientTransport.send({ jsonrpc: "2.0", method: "test", id: 1 }) + ).rejects.toThrow("Not connected"); + + await expect( + serverTransport.send({ jsonrpc: "2.0", method: "test", id: 2 }) + ).rejects.toThrow("Not connected"); + }); + }); diff --git a/src/inMemory.ts b/src/inMemory.ts index 5dd6e81e..06a8db7c 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -13,6 +13,8 @@ interface QueuedMessage { export class InMemoryTransport implements Transport { private _otherTransport?: InMemoryTransport; private _messageQueue: QueuedMessage[] = []; + private _isClosed = false; + private _closePromise?: Promise; onclose?: () => void; onerror?: (error: Error) => void; @@ -39,10 +41,18 @@ export class InMemoryTransport implements Transport { } async close(): Promise { - const other = this._otherTransport; - this._otherTransport = undefined; - await other?.close(); - this.onclose?.(); + if (this._isClosed) return this._closePromise ?? Promise.resolve(); + + this._isClosed = true; + this._closePromise = (async () => { + const peer = this._otherTransport; + this._otherTransport = undefined; // Prevent infinite recursion + + this.onclose?.(); + await peer?.close(); + })(); + + return this._closePromise; } /** diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 7df190ba..178846d2 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -334,12 +334,24 @@ export abstract class Protocol< this._progressHandlers.clear(); this._pendingDebouncedNotifications.clear(); this._transport = undefined; + + // Abort all active request handlers + const requestHandlerAbortControllers = this._requestHandlerAbortControllers; + this._requestHandlerAbortControllers = new Map(); + this.onclose?.(); const error = new McpError(ErrorCode.ConnectionClosed, "Connection closed"); + + // Reject all pending response handlers (for outgoing requests) for (const handler of responseHandlers.values()) { handler(error); } + + // Abort all active request handlers (for incoming requests being processed) + for (const abortController of requestHandlerAbortControllers.values()) { + abortController.abort(error); + } } private _onerror(error: Error): void {