1- import { randomUUID } from "node:crypto"
2- import { IncomingMessage , Server as HttpServer , ServerResponse , createServer } from "node:http"
3- import { JSONRPCMessage , ClientRequest } from "@modelcontextprotocol/sdk/types.js"
4- import contentType from "content-type"
5- import getRawBody from "raw-body"
6- import { APIKeyAuthProvider } from "../../auth/providers/apikey.js"
7- import { DEFAULT_AUTH_ERROR } from "../../auth/types.js"
8- import { AbstractTransport } from "../base.js"
9- import { DEFAULT_SSE_CONFIG , SSETransportConfig , SSETransportConfigInternal , DEFAULT_CORS_CONFIG , CORSConfig } from "./types.js"
10- import { logger } from "../../core/Logger.js"
11- import { getRequestHeader , setResponseHeaders } from "../../utils/headers.js"
1+ import { randomUUID } from "node:crypto" ;
2+ import { IncomingMessage , Server as HttpServer , ServerResponse , createServer } from "node:http" ;
3+ import { JSONRPCMessage , ClientRequest } from "@modelcontextprotocol/sdk/types.js" ;
4+ import contentType from "content-type" ;
5+ import getRawBody from "raw-body" ;
6+ import { APIKeyAuthProvider } from "../../auth/providers/apikey.js" ;
7+ import { DEFAULT_AUTH_ERROR } from "../../auth/types.js" ;
8+ import { AbstractTransport } from "../base.js" ;
9+ import { DEFAULT_SSE_CONFIG , SSETransportConfig , SSETransportConfigInternal , DEFAULT_CORS_CONFIG , CORSConfig } from "./types.js" ;
10+ import { logger } from "../../core/Logger.js" ;
11+ import { getRequestHeader , setResponseHeaders } from "../../utils/headers.js" ;
1212import { PING_SSE_MESSAGE } from "../utils/ping-message.js" ;
1313
14- interface ExtendedIncomingMessage extends IncomingMessage {
15- body ?: ClientRequest
16- }
1714
1815const SSE_HEADERS = {
1916 "Content-Type" : "text/event-stream" ,
@@ -25,14 +22,14 @@ export class SSEServerTransport extends AbstractTransport {
2522 readonly type = "sse"
2623
2724 private _server ?: HttpServer
28- private _sseResponse ?: ServerResponse
29- private _sessionId : string
25+ private _connections : Map < string , { res : ServerResponse , intervalId : NodeJS . Timeout } > // Map<connectionId, { res: ServerResponse, intervalId: NodeJS.Timeout }>
26+ private _sessionId : string // Server instance ID
3027 private _config : SSETransportConfigInternal
31- private _keepAliveInterval ?: NodeJS . Timeout
3228
3329 constructor ( config : SSETransportConfig = { } ) {
3430 super ( )
35- this . _sessionId = randomUUID ( )
31+ this . _connections = new Map ( )
32+ this . _sessionId = randomUUID ( ) // Used to validate POST messages belong to this server instance
3633 this . _config = {
3734 ...DEFAULT_SSE_CONFIG ,
3835 ...config
@@ -76,11 +73,11 @@ export class SSEServerTransport extends AbstractTransport {
7673 }
7774
7875 return new Promise ( ( resolve ) => {
79- this . _server = createServer ( async ( req , res ) => {
76+ this . _server = createServer ( async ( req : IncomingMessage , res : ServerResponse ) => {
8077 try {
8178 await this . handleRequest ( req , res )
82- } catch ( error ) {
83- logger . error ( `Error handling request: ${ error } ` )
79+ } catch ( error : any ) {
80+ logger . error ( `Error handling request: ${ error instanceof Error ? error . message : String ( error ) } ` )
8481 res . writeHead ( 500 ) . end ( "Internal Server Error" )
8582 }
8683 } )
@@ -90,8 +87,8 @@ export class SSEServerTransport extends AbstractTransport {
9087 resolve ( )
9188 } )
9289
93- this . _server . on ( "error" , ( error ) => {
94- logger . error ( `SSE server error: ${ error } ` )
90+ this . _server . on ( "error" , ( error : Error ) => {
91+ logger . error ( `SSE server error: ${ error . message } ` )
9592 this . _onerror ?.( error )
9693 } )
9794
@@ -102,7 +99,7 @@ export class SSEServerTransport extends AbstractTransport {
10299 } )
103100 }
104101
105- private async handleRequest ( req : ExtendedIncomingMessage , res : ServerResponse ) : Promise < void > {
102+ private async handleRequest ( req : IncomingMessage , res : ServerResponse ) : Promise < void > {
106103 logger . debug ( `Incoming request: ${ req . method } ${ req . url } ` )
107104
108105 if ( req . method === "OPTIONS" ) {
@@ -122,25 +119,23 @@ export class SSEServerTransport extends AbstractTransport {
122119 if ( ! isAuthenticated ) return
123120 }
124121
125- if ( this . _sseResponse ?. writableEnded ) {
126- this . _sseResponse = undefined
127- }
128-
129- if ( this . _sseResponse ) {
130- logger . warn ( "SSE connection already established; closing the old connection to allow a new one." )
131- this . _sseResponse . end ( )
132- this . cleanupConnection ( )
133- }
134-
135- this . setupSSEConnection ( res )
136- return
122+ // Remove check for existing single _sseResponse
123+ // Generate a unique ID for this specific connection
124+ const connectionId = randomUUID ( ) ;
125+ this . setupSSEConnection ( res , connectionId ) ;
126+ return ;
137127 }
138128
139129 if ( req . method === "POST" && url . pathname === this . _config . messageEndpoint ) {
140- if ( sessionId !== this . _sessionId ) {
141- logger . warn ( `Invalid session ID received: ${ sessionId } , expected: ${ this . _sessionId } ` )
142- res . writeHead ( 403 ) . end ( "Invalid session ID" )
143- return
130+ // **Connection Validation (User Requested):**
131+ // Check if the 'sessionId' from the POST request URL query parameter
132+ // (which should contain a connectionId provided by the server via the 'endpoint' event)
133+ // corresponds to an active connection in the `_connections` map.
134+ if ( ! sessionId || ! this . _connections . has ( sessionId ) ) {
135+ logger . warn ( `Invalid or inactive connection ID in POST request URL: ${ sessionId } ` ) ;
136+ // Use 403 Forbidden as the client is attempting an operation for an invalid/unknown connection
137+ res . writeHead ( 403 ) . end ( "Invalid or inactive connection ID" ) ;
138+ return ;
144139 }
145140
146141 if ( this . _config . auth ?. endpoints ?. messages !== false ) {
@@ -155,7 +150,7 @@ export class SSEServerTransport extends AbstractTransport {
155150 res . writeHead ( 404 ) . end ( "Not Found" )
156151 }
157152
158- private async handleAuthentication ( req : ExtendedIncomingMessage , res : ServerResponse , context : string ) : Promise < boolean > {
153+ private async handleAuthentication ( req : IncomingMessage , res : ServerResponse , context : string ) : Promise < boolean > {
159154 if ( ! this . _config . auth ?. provider ) {
160155 return true
161156 }
@@ -203,9 +198,8 @@ export class SSEServerTransport extends AbstractTransport {
203198 return true
204199 }
205200
206- private setupSSEConnection ( res : ServerResponse ) : void {
207- logger . debug ( `Setting up SSE connection for session: ${ this . _sessionId } ` )
208-
201+ private setupSSEConnection ( res : ServerResponse , connectionId : string ) : void {
202+ logger . debug ( `Setting up SSE connection: ${ connectionId } for server session: ${ this . _sessionId } ` ) ;
209203 const headers = {
210204 ...SSE_HEADERS ,
211205 ...this . getCorsHeaders ( ) ,
@@ -218,60 +212,65 @@ export class SSEServerTransport extends AbstractTransport {
218212 res . socket . setNoDelay ( true )
219213 res . socket . setTimeout ( 0 )
220214 res . socket . setKeepAlive ( true , 1000 )
221- logger . debug ( 'Socket optimized for SSE connection' )
215+ logger . debug ( 'Socket optimized for SSE connection' ) ;
222216 }
223-
224- const endpointUrl = `${ this . _config . messageEndpoint } ?sessionId=${ this . _sessionId } `
225- logger . debug ( `Sending endpoint URL: ${ endpointUrl } ` )
226- res . write ( `event: endpoint\ndata: ${ endpointUrl } \n\n` )
227-
228- logger . debug ( 'Sending initial keep-alive' )
229-
230- this . _keepAliveInterval = setInterval ( ( ) => {
231- if ( this . _sseResponse && ! this . _sseResponse . writableEnded ) {
232- try {
233- this . _sseResponse . write ( PING_SSE_MESSAGE ) ;
234- } catch ( error ) {
235- logger . error ( `Error sending keep-alive: ${ error } ` )
236- this . cleanupConnection ( )
217+ // **Important Change:** The endpoint URL now includes the specific connectionId
218+ // in the 'sessionId' query parameter, as requested by user feedback.
219+ // The client should use this exact URL for subsequent POST messages.
220+ const endpointUrl = `${ this . _config . messageEndpoint } ?sessionId=${ connectionId } ` ;
221+ logger . debug ( `Sending endpoint URL for connection ${ connectionId } : ${ endpointUrl } ` ) ;
222+ res . write ( `event: endpoint\ndata: ${ endpointUrl } \n\n` ) ;
223+ // Send the unique connection ID separately as well for potential client-side use
224+ res . write ( `event: connectionId\ndata: ${ connectionId } \n\n` ) ;
225+ logger . debug ( `Sending initial keep-alive for connection: ${ connectionId } ` ) ;
226+ const intervalId = setInterval ( ( ) => {
227+ const connection = this . _connections . get ( connectionId ) ;
228+ if ( connection && ! connection . res . writableEnded ) {
229+ try {
230+ connection . res . write ( PING_SSE_MESSAGE ) ;
231+ }
232+ catch ( error : any ) {
233+ logger . error ( `Error sending keep-alive for connection ${ connectionId } : ${ error instanceof Error ? error . message : String ( error ) } ` ) ;
234+ this . cleanupConnection ( connectionId ) ;
235+ }
237236 }
238- }
239- } , 15000 )
240-
241- this . _sseResponse = res
242-
243- const cleanup = ( ) => this . cleanupConnection ( )
244-
237+ else {
238+ // Should not happen if cleanup is working, but clear interval just in case
239+ logger . warn ( `Keep-alive interval running for missing/ended connection: ${ connectionId } ` ) ;
240+ this . cleanupConnection ( connectionId ) ; // Will clear interval
241+ }
242+ } , 15000 ) ;
243+ this . _connections . set ( connectionId , { res, intervalId } ) ;
244+ const cleanup = ( ) => this . cleanupConnection ( connectionId ) ;
245245 res . on ( "close" , ( ) => {
246- logger . info ( `SSE connection closed for session: ${ this . _sessionId } ` )
247- cleanup ( )
248- } )
249-
250- res . on ( "error" , ( error ) => {
251- logger . error ( `SSE connection error for session ${ this . _sessionId } : ${ error } ` )
252- this . _onerror ?.( error )
253- cleanup ( )
254- } )
255-
246+ logger . info ( `SSE connection closed: ${ connectionId } ` ) ;
247+ cleanup ( ) ;
248+ } ) ;
249+ res . on ( "error" , ( error : Error ) => {
250+ logger . error ( `SSE connection error for ${ connectionId } : ${ error . message } ` ) ;
251+ this . _onerror ?.( error ) ;
252+ cleanup ( ) ;
253+ } ) ;
256254 res . on ( "end" , ( ) => {
257- logger . info ( `SSE connection ended for session: ${ this . _sessionId } ` )
258- cleanup ( )
259- } )
260-
261- logger . info ( `SSE connection established successfully for session: ${ this . _sessionId } ` )
255+ logger . info ( `SSE connection ended: ${ connectionId } ` ) ;
256+ cleanup ( ) ;
257+ } ) ;
258+ logger . info ( `SSE connection established successfully: ${ connectionId } ` ) ;
262259 }
263260
264- private async handlePostMessage ( req : ExtendedIncomingMessage , res : ServerResponse ) : Promise < void > {
265- if ( ! this . _sseResponse || this . _sseResponse . writableEnded ) {
266- logger . warn ( `Rejecting message: no active SSE connection for session ${ this . _sessionId } ` )
267- res . writeHead ( 409 ) . end ( "SSE connection not established" )
268- return
261+ private async handlePostMessage ( req : IncomingMessage , res : ServerResponse ) : Promise < void > {
262+ // Check if *any* connection is active, not just the old single _sseResponse
263+ if ( this . _connections . size === 0 ) {
264+ logger . warn ( `Rejecting message: no active SSE connections for server session ${ this . _sessionId } ` ) ;
265+ // Use 409 Conflict as it indicates the server state prevents fulfilling the request
266+ res . writeHead ( 409 ) . end ( "No active SSE connection established" ) ;
267+ return ;
269268 }
270269
271270 let currentMessage : { id ?: string | number ; method ?: string } = { }
272271
273272 try {
274- const rawMessage = req . body || await ( async ( ) => {
273+ const rawMessage = ( req as any ) . body || await ( async ( ) => { // Cast req to any to access potential body property
275274 const ct = contentType . parse ( req . headers [ "content-type" ] ?? "" )
276275 if ( ct . type !== "application/json" ) {
277276 throw new Error ( `Unsupported content-type: ${ ct . type } ` )
@@ -316,7 +315,7 @@ export class SSEServerTransport extends AbstractTransport {
316315
317316 logger . debug ( `Successfully processed message ${ rpcMessage . id } ` )
318317
319- } catch ( error ) {
318+ } catch ( error : any ) {
320319 const errorMessage = error instanceof Error ? error . message : String ( error )
321320 logger . error ( `Error handling message for session ${ this . _sessionId } :` )
322321 logger . error ( `- Error: ${ errorMessage } ` )
@@ -332,7 +331,7 @@ export class SSEServerTransport extends AbstractTransport {
332331 data : {
333332 method : currentMessage . method || "unknown" ,
334333 sessionId : this . _sessionId ,
335- connectionActive : Boolean ( this . _sseResponse ) ,
334+ connectionActive : Boolean ( this . _connections . size > 0 ) ,
336335 type : "message_handler_error"
337336 }
338337 }
@@ -343,42 +342,85 @@ export class SSEServerTransport extends AbstractTransport {
343342 }
344343 }
345344
345+ // Broadcast message to all connected clients
346346 async send ( message : JSONRPCMessage ) : Promise < void > {
347- if ( ! this . _sseResponse || this . _sseResponse . writableEnded ) {
348- throw new Error ( "SSE connection not established" )
349- }
350-
351- this . _sseResponse . write ( `data: ${ JSON . stringify ( message ) } \n\n` )
347+ if ( this . _connections . size === 0 ) {
348+ logger . warn ( "Attempted to send message, but no clients are connected." ) ;
349+ // Optionally throw an error or just log
350+ // throw new Error("No SSE connections established");
351+ return ;
352+ }
353+ const messageString = `data: ${ JSON . stringify ( message ) } \n\n` ;
354+ logger . debug ( `Broadcasting message to ${ this . _connections . size } clients: ${ JSON . stringify ( message ) } ` ) ;
355+ let failedSends = 0 ;
356+ for ( const [ connectionId , connection ] of this . _connections . entries ( ) ) {
357+ if ( connection . res && ! connection . res . writableEnded ) {
358+ try {
359+ connection . res . write ( messageString ) ;
360+ }
361+ catch ( error : any ) {
362+ failedSends ++ ;
363+ logger . error ( `Error sending message to connection ${ connectionId } : ${ error instanceof Error ? error . message : String ( error ) } ` ) ;
364+ // Clean up the problematic connection
365+ this . cleanupConnection ( connectionId ) ;
366+ }
367+ }
368+ else {
369+ // Should not happen if cleanup is working, but handle defensively
370+ logger . warn ( `Attempted to send to ended connection: ${ connectionId } ` ) ;
371+ this . cleanupConnection ( connectionId ) ;
372+ }
373+ }
374+ if ( failedSends > 0 ) {
375+ logger . warn ( `Failed to send message to ${ failedSends } connections.` ) ;
376+ }
352377 }
353378
354379 async close ( ) : Promise < void > {
355- if ( this . _sseResponse && ! this . _sseResponse . writableEnded ) {
356- this . _sseResponse . end ( )
357- }
358-
359- this . cleanupConnection ( )
360-
361- return new Promise ( ( resolve ) => {
362- if ( ! this . _server ) {
363- resolve ( )
364- return
380+ logger . info ( `Closing SSE transport and ${ this . _connections . size } connections.` ) ;
381+ // Close all active client connections
382+ for ( const connectionId of this . _connections . keys ( ) ) {
383+ this . cleanupConnection ( connectionId , true ) ; // Pass true to end the response
365384 }
366-
367- this . _server . close ( ( ) => {
368- logger . info ( "SSE server stopped" )
369- this . _server = undefined
370- this . _onclose ?.( )
371- resolve ( )
372- } )
373- } )
385+ this . _connections . clear ( ) ; // Ensure map is empty
386+ // Close the main server
387+ return new Promise ( ( resolve ) => {
388+ if ( ! this . _server ) {
389+ logger . debug ( "Server already stopped." ) ;
390+ resolve ( ) ;
391+ return ;
392+ }
393+ this . _server . close ( ( ) => {
394+ logger . info ( "SSE server stopped" ) ;
395+ this . _server = undefined ;
396+ this . _onclose ?.( ) ;
397+ resolve ( ) ;
398+ } ) ;
399+ } ) ;
374400 }
375401
376- private cleanupConnection ( ) : void {
377- if ( this . _keepAliveInterval ) {
378- clearInterval ( this . _keepAliveInterval )
379- this . _keepAliveInterval = undefined
380- }
381- this . _sseResponse = undefined
402+ // Clean up a specific connection by its ID
403+ private cleanupConnection ( connectionId : string , endResponse = false ) : void {
404+ const connection = this . _connections . get ( connectionId ) ;
405+ if ( connection ) {
406+ logger . debug ( `Cleaning up connection: ${ connectionId } ` ) ;
407+ if ( connection . intervalId ) {
408+ clearInterval ( connection . intervalId ) ;
409+ }
410+ if ( endResponse && connection . res && ! connection . res . writableEnded ) {
411+ try {
412+ connection . res . end ( ) ;
413+ }
414+ catch ( e : any ) {
415+ logger . warn ( `Error ending response for connection ${ connectionId } : ${ e instanceof Error ? e . message : String ( e ) } ` ) ;
416+ }
417+ }
418+ this . _connections . delete ( connectionId ) ;
419+ logger . debug ( `Connection removed: ${ connectionId } . Remaining connections: ${ this . _connections . size } ` ) ;
420+ }
421+ else {
422+ logger . debug ( `Attempted to clean up non-existent connection: ${ connectionId } ` ) ;
423+ }
382424 }
383425
384426 isRunning ( ) : boolean {
0 commit comments