Skip to content

Commit a953d4e

Browse files
committed
refactor(sse): streamline SSE connection management logic
1 parent b6de268 commit a953d4e

File tree

2 files changed

+162
-119
lines changed

2 files changed

+162
-119
lines changed

package-lock.json

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/transports/sse/server.ts

Lines changed: 160 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import { randomUUID } from "node:crypto"
2-
import { IncomingMessage, Server as HttpServer, ServerResponse, createServer } from "node:http"
3-
import { JSONRPCMessage, ClientRequest } from "@modelcontextprotocol/sdk/types.js"
4-
import contentType from "content-type"
5-
import getRawBody from "raw-body"
6-
import { APIKeyAuthProvider } from "../../auth/providers/apikey.js"
7-
import { DEFAULT_AUTH_ERROR } from "../../auth/types.js"
8-
import { AbstractTransport } from "../base.js"
9-
import { DEFAULT_SSE_CONFIG, SSETransportConfig, SSETransportConfigInternal, DEFAULT_CORS_CONFIG, CORSConfig } from "./types.js"
10-
import { logger } from "../../core/Logger.js"
11-
import { getRequestHeader, setResponseHeaders } from "../../utils/headers.js"
1+
import { randomUUID } from "node:crypto";
2+
import { IncomingMessage, Server as HttpServer, ServerResponse, createServer } from "node:http";
3+
import { JSONRPCMessage, ClientRequest } from "@modelcontextprotocol/sdk/types.js";
4+
import contentType from "content-type";
5+
import getRawBody from "raw-body";
6+
import { APIKeyAuthProvider } from "../../auth/providers/apikey.js";
7+
import { DEFAULT_AUTH_ERROR } from "../../auth/types.js";
8+
import { AbstractTransport } from "../base.js";
9+
import { DEFAULT_SSE_CONFIG, SSETransportConfig, SSETransportConfigInternal, DEFAULT_CORS_CONFIG, CORSConfig } from "./types.js";
10+
import { logger } from "../../core/Logger.js";
11+
import { getRequestHeader, setResponseHeaders } from "../../utils/headers.js";
1212
import { PING_SSE_MESSAGE } from "../utils/ping-message.js";
1313

14-
interface ExtendedIncomingMessage extends IncomingMessage {
15-
body?: ClientRequest
16-
}
1714

1815
const SSE_HEADERS = {
1916
"Content-Type": "text/event-stream",
@@ -25,14 +22,14 @@ export class SSEServerTransport extends AbstractTransport {
2522
readonly type = "sse"
2623

2724
private _server?: HttpServer
28-
private _sseResponse?: ServerResponse
29-
private _sessionId: string
25+
private _connections: Map<string, { res: ServerResponse, intervalId: NodeJS.Timeout }> // Map<connectionId, { res: ServerResponse, intervalId: NodeJS.Timeout }>
26+
private _sessionId: string // Server instance ID
3027
private _config: SSETransportConfigInternal
31-
private _keepAliveInterval?: NodeJS.Timeout
3228

3329
constructor(config: SSETransportConfig = {}) {
3430
super()
35-
this._sessionId = randomUUID()
31+
this._connections = new Map()
32+
this._sessionId = randomUUID() // Used to validate POST messages belong to this server instance
3633
this._config = {
3734
...DEFAULT_SSE_CONFIG,
3835
...config
@@ -76,11 +73,11 @@ export class SSEServerTransport extends AbstractTransport {
7673
}
7774

7875
return new Promise((resolve) => {
79-
this._server = createServer(async (req, res) => {
76+
this._server = createServer(async (req: IncomingMessage, res: ServerResponse) => {
8077
try {
8178
await this.handleRequest(req, res)
82-
} catch (error) {
83-
logger.error(`Error handling request: ${error}`)
79+
} catch (error: any) {
80+
logger.error(`Error handling request: ${error instanceof Error ? error.message : String(error)}`)
8481
res.writeHead(500).end("Internal Server Error")
8582
}
8683
})
@@ -90,8 +87,8 @@ export class SSEServerTransport extends AbstractTransport {
9087
resolve()
9188
})
9289

93-
this._server.on("error", (error) => {
94-
logger.error(`SSE server error: ${error}`)
90+
this._server.on("error", (error: Error) => {
91+
logger.error(`SSE server error: ${error.message}`)
9592
this._onerror?.(error)
9693
})
9794

@@ -102,7 +99,7 @@ export class SSEServerTransport extends AbstractTransport {
10299
})
103100
}
104101

105-
private async handleRequest(req: ExtendedIncomingMessage, res: ServerResponse): Promise<void> {
102+
private async handleRequest(req: IncomingMessage, res: ServerResponse): Promise<void> {
106103
logger.debug(`Incoming request: ${req.method} ${req.url}`)
107104

108105
if (req.method === "OPTIONS") {
@@ -122,25 +119,23 @@ export class SSEServerTransport extends AbstractTransport {
122119
if (!isAuthenticated) return
123120
}
124121

125-
if (this._sseResponse?.writableEnded) {
126-
this._sseResponse = undefined
127-
}
128-
129-
if (this._sseResponse) {
130-
logger.warn("SSE connection already established; closing the old connection to allow a new one.")
131-
this._sseResponse.end()
132-
this.cleanupConnection()
133-
}
134-
135-
this.setupSSEConnection(res)
136-
return
122+
// Remove check for existing single _sseResponse
123+
// Generate a unique ID for this specific connection
124+
const connectionId = randomUUID();
125+
this.setupSSEConnection(res, connectionId);
126+
return;
137127
}
138128

139129
if (req.method === "POST" && url.pathname === this._config.messageEndpoint) {
140-
if (sessionId !== this._sessionId) {
141-
logger.warn(`Invalid session ID received: ${sessionId}, expected: ${this._sessionId}`)
142-
res.writeHead(403).end("Invalid session ID")
143-
return
130+
// **Connection Validation (User Requested):**
131+
// Check if the 'sessionId' from the POST request URL query parameter
132+
// (which should contain a connectionId provided by the server via the 'endpoint' event)
133+
// corresponds to an active connection in the `_connections` map.
134+
if (!sessionId || !this._connections.has(sessionId)) {
135+
logger.warn(`Invalid or inactive connection ID in POST request URL: ${sessionId}`);
136+
// Use 403 Forbidden as the client is attempting an operation for an invalid/unknown connection
137+
res.writeHead(403).end("Invalid or inactive connection ID");
138+
return;
144139
}
145140

146141
if (this._config.auth?.endpoints?.messages !== false) {
@@ -155,7 +150,7 @@ export class SSEServerTransport extends AbstractTransport {
155150
res.writeHead(404).end("Not Found")
156151
}
157152

158-
private async handleAuthentication(req: ExtendedIncomingMessage, res: ServerResponse, context: string): Promise<boolean> {
153+
private async handleAuthentication(req: IncomingMessage, res: ServerResponse, context: string): Promise<boolean> {
159154
if (!this._config.auth?.provider) {
160155
return true
161156
}
@@ -203,9 +198,8 @@ export class SSEServerTransport extends AbstractTransport {
203198
return true
204199
}
205200

206-
private setupSSEConnection(res: ServerResponse): void {
207-
logger.debug(`Setting up SSE connection for session: ${this._sessionId}`)
208-
201+
private setupSSEConnection(res: ServerResponse, connectionId: string): void {
202+
logger.debug(`Setting up SSE connection: ${connectionId} for server session: ${this._sessionId}`);
209203
const headers = {
210204
...SSE_HEADERS,
211205
...this.getCorsHeaders(),
@@ -218,60 +212,65 @@ export class SSEServerTransport extends AbstractTransport {
218212
res.socket.setNoDelay(true)
219213
res.socket.setTimeout(0)
220214
res.socket.setKeepAlive(true, 1000)
221-
logger.debug('Socket optimized for SSE connection')
215+
logger.debug('Socket optimized for SSE connection');
222216
}
223-
224-
const endpointUrl = `${this._config.messageEndpoint}?sessionId=${this._sessionId}`
225-
logger.debug(`Sending endpoint URL: ${endpointUrl}`)
226-
res.write(`event: endpoint\ndata: ${endpointUrl}\n\n`)
227-
228-
logger.debug('Sending initial keep-alive')
229-
230-
this._keepAliveInterval = setInterval(() => {
231-
if (this._sseResponse && !this._sseResponse.writableEnded) {
232-
try {
233-
this._sseResponse.write(PING_SSE_MESSAGE);
234-
} catch (error) {
235-
logger.error(`Error sending keep-alive: ${error}`)
236-
this.cleanupConnection()
217+
// **Important Change:** The endpoint URL now includes the specific connectionId
218+
// in the 'sessionId' query parameter, as requested by user feedback.
219+
// The client should use this exact URL for subsequent POST messages.
220+
const endpointUrl = `${this._config.messageEndpoint}?sessionId=${connectionId}`;
221+
logger.debug(`Sending endpoint URL for connection ${connectionId}: ${endpointUrl}`);
222+
res.write(`event: endpoint\ndata: ${endpointUrl}\n\n`);
223+
// Send the unique connection ID separately as well for potential client-side use
224+
res.write(`event: connectionId\ndata: ${connectionId}\n\n`);
225+
logger.debug(`Sending initial keep-alive for connection: ${connectionId}`);
226+
const intervalId = setInterval(() => {
227+
const connection = this._connections.get(connectionId);
228+
if (connection && !connection.res.writableEnded) {
229+
try {
230+
connection.res.write(PING_SSE_MESSAGE);
231+
}
232+
catch (error: any) {
233+
logger.error(`Error sending keep-alive for connection ${connectionId}: ${error instanceof Error ? error.message : String(error)}`);
234+
this.cleanupConnection(connectionId);
235+
}
237236
}
238-
}
239-
}, 15000)
240-
241-
this._sseResponse = res
242-
243-
const cleanup = () => this.cleanupConnection()
244-
237+
else {
238+
// Should not happen if cleanup is working, but clear interval just in case
239+
logger.warn(`Keep-alive interval running for missing/ended connection: ${connectionId}`);
240+
this.cleanupConnection(connectionId); // Will clear interval
241+
}
242+
}, 15000);
243+
this._connections.set(connectionId, { res, intervalId });
244+
const cleanup = () => this.cleanupConnection(connectionId);
245245
res.on("close", () => {
246-
logger.info(`SSE connection closed for session: ${this._sessionId}`)
247-
cleanup()
248-
})
249-
250-
res.on("error", (error) => {
251-
logger.error(`SSE connection error for session ${this._sessionId}: ${error}`)
252-
this._onerror?.(error)
253-
cleanup()
254-
})
255-
246+
logger.info(`SSE connection closed: ${connectionId}`);
247+
cleanup();
248+
});
249+
res.on("error", (error: Error) => {
250+
logger.error(`SSE connection error for ${connectionId}: ${error.message}`);
251+
this._onerror?.(error);
252+
cleanup();
253+
});
256254
res.on("end", () => {
257-
logger.info(`SSE connection ended for session: ${this._sessionId}`)
258-
cleanup()
259-
})
260-
261-
logger.info(`SSE connection established successfully for session: ${this._sessionId}`)
255+
logger.info(`SSE connection ended: ${connectionId}`);
256+
cleanup();
257+
});
258+
logger.info(`SSE connection established successfully: ${connectionId}`);
262259
}
263260

264-
private async handlePostMessage(req: ExtendedIncomingMessage, res: ServerResponse): Promise<void> {
265-
if (!this._sseResponse || this._sseResponse.writableEnded) {
266-
logger.warn(`Rejecting message: no active SSE connection for session ${this._sessionId}`)
267-
res.writeHead(409).end("SSE connection not established")
268-
return
261+
private async handlePostMessage(req: IncomingMessage, res: ServerResponse): Promise<void> {
262+
// Check if *any* connection is active, not just the old single _sseResponse
263+
if (this._connections.size === 0) {
264+
logger.warn(`Rejecting message: no active SSE connections for server session ${this._sessionId}`);
265+
// Use 409 Conflict as it indicates the server state prevents fulfilling the request
266+
res.writeHead(409).end("No active SSE connection established");
267+
return;
269268
}
270269

271270
let currentMessage: { id?: string | number; method?: string } = {}
272271

273272
try {
274-
const rawMessage = req.body || await (async () => {
273+
const rawMessage = (req as any).body || await (async () => { // Cast req to any to access potential body property
275274
const ct = contentType.parse(req.headers["content-type"] ?? "")
276275
if (ct.type !== "application/json") {
277276
throw new Error(`Unsupported content-type: ${ct.type}`)
@@ -316,7 +315,7 @@ export class SSEServerTransport extends AbstractTransport {
316315

317316
logger.debug(`Successfully processed message ${rpcMessage.id}`)
318317

319-
} catch (error) {
318+
} catch (error: any) {
320319
const errorMessage = error instanceof Error ? error.message : String(error)
321320
logger.error(`Error handling message for session ${this._sessionId}:`)
322321
logger.error(`- Error: ${errorMessage}`)
@@ -332,7 +331,7 @@ export class SSEServerTransport extends AbstractTransport {
332331
data: {
333332
method: currentMessage.method || "unknown",
334333
sessionId: this._sessionId,
335-
connectionActive: Boolean(this._sseResponse),
334+
connectionActive: Boolean(this._connections.size > 0),
336335
type: "message_handler_error"
337336
}
338337
}
@@ -343,42 +342,85 @@ export class SSEServerTransport extends AbstractTransport {
343342
}
344343
}
345344

345+
// Broadcast message to all connected clients
346346
async send(message: JSONRPCMessage): Promise<void> {
347-
if (!this._sseResponse || this._sseResponse.writableEnded) {
348-
throw new Error("SSE connection not established")
349-
}
350-
351-
this._sseResponse.write(`data: ${JSON.stringify(message)}\n\n`)
347+
if (this._connections.size === 0) {
348+
logger.warn("Attempted to send message, but no clients are connected.");
349+
// Optionally throw an error or just log
350+
// throw new Error("No SSE connections established");
351+
return;
352+
}
353+
const messageString = `data: ${JSON.stringify(message)}\n\n`;
354+
logger.debug(`Broadcasting message to ${this._connections.size} clients: ${JSON.stringify(message)}`);
355+
let failedSends = 0;
356+
for (const [connectionId, connection] of this._connections.entries()) {
357+
if (connection.res && !connection.res.writableEnded) {
358+
try {
359+
connection.res.write(messageString);
360+
}
361+
catch (error: any) {
362+
failedSends++;
363+
logger.error(`Error sending message to connection ${connectionId}: ${error instanceof Error ? error.message : String(error)}`);
364+
// Clean up the problematic connection
365+
this.cleanupConnection(connectionId);
366+
}
367+
}
368+
else {
369+
// Should not happen if cleanup is working, but handle defensively
370+
logger.warn(`Attempted to send to ended connection: ${connectionId}`);
371+
this.cleanupConnection(connectionId);
372+
}
373+
}
374+
if (failedSends > 0) {
375+
logger.warn(`Failed to send message to ${failedSends} connections.`);
376+
}
352377
}
353378

354379
async close(): Promise<void> {
355-
if (this._sseResponse && !this._sseResponse.writableEnded) {
356-
this._sseResponse.end()
357-
}
358-
359-
this.cleanupConnection()
360-
361-
return new Promise((resolve) => {
362-
if (!this._server) {
363-
resolve()
364-
return
380+
logger.info(`Closing SSE transport and ${this._connections.size} connections.`);
381+
// Close all active client connections
382+
for (const connectionId of this._connections.keys()) {
383+
this.cleanupConnection(connectionId, true); // Pass true to end the response
365384
}
366-
367-
this._server.close(() => {
368-
logger.info("SSE server stopped")
369-
this._server = undefined
370-
this._onclose?.()
371-
resolve()
372-
})
373-
})
385+
this._connections.clear(); // Ensure map is empty
386+
// Close the main server
387+
return new Promise((resolve) => {
388+
if (!this._server) {
389+
logger.debug("Server already stopped.");
390+
resolve();
391+
return;
392+
}
393+
this._server.close(() => {
394+
logger.info("SSE server stopped");
395+
this._server = undefined;
396+
this._onclose?.();
397+
resolve();
398+
});
399+
});
374400
}
375401

376-
private cleanupConnection(): void {
377-
if (this._keepAliveInterval) {
378-
clearInterval(this._keepAliveInterval)
379-
this._keepAliveInterval = undefined
380-
}
381-
this._sseResponse = undefined
402+
// Clean up a specific connection by its ID
403+
private cleanupConnection(connectionId: string, endResponse = false): void {
404+
const connection = this._connections.get(connectionId);
405+
if (connection) {
406+
logger.debug(`Cleaning up connection: ${connectionId}`);
407+
if (connection.intervalId) {
408+
clearInterval(connection.intervalId);
409+
}
410+
if (endResponse && connection.res && !connection.res.writableEnded) {
411+
try {
412+
connection.res.end();
413+
}
414+
catch (e: any) {
415+
logger.warn(`Error ending response for connection ${connectionId}: ${e instanceof Error ? e.message : String(e)}`);
416+
}
417+
}
418+
this._connections.delete(connectionId);
419+
logger.debug(`Connection removed: ${connectionId}. Remaining connections: ${this._connections.size}`);
420+
}
421+
else {
422+
logger.debug(`Attempted to clean up non-existent connection: ${connectionId}`);
423+
}
382424
}
383425

384426
isRunning(): boolean {

0 commit comments

Comments
 (0)