1717import java .util .concurrent .atomic .AtomicBoolean ;
1818import java .util .concurrent .atomic .AtomicLong ;
1919import java .util .function .Supplier ;
20+ import java .util .stream .Collectors ;
2021
2122import com .fasterxml .jackson .core .type .TypeReference ;
2223import com .fasterxml .jackson .databind .ObjectMapper ;
3031import io .modelcontextprotocol .spec .McpServerSession ;
3132import io .modelcontextprotocol .spec .McpServerTransport ;
3233import io .modelcontextprotocol .spec .McpServerTransportProvider ;
34+ import io .modelcontextprotocol .spec .SseEvent ;
3335import io .modelcontextprotocol .util .Assert ;
3436import jakarta .servlet .AsyncContext ;
3537import jakarta .servlet .ReadListener ;
@@ -89,6 +91,8 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
8991
9092 public static final String ALLOW_ORIGIN_DEFAULT_VALUE = "*" ;
9193
94+ public static final String PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version" ;
95+
9296 public static final String CACHE_CONTROL_HEADER = "Cache-Control" ;
9397
9498 public static final String CONNECTION_HEADER = "Connection" ;
@@ -117,7 +121,7 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
117121 private final Supplier <String > sessionIdProvider ;
118122
119123 /** Sessions map, keyed by Session ID */
120- private final Map <String , McpServerSession > sessions = new ConcurrentHashMap <>();
124+ private static final Map <String , McpServerSession > sessions = new ConcurrentHashMap <>();
121125
122126 /** Flag indicating if the transport is in the process of shutting down */
123127 private final AtomicBoolean isClosing = new AtomicBoolean (false );
@@ -128,6 +132,7 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
128132 /** Callback interface for session lifecycle and errors */
129133 private SessionHandler sessionHandler ;
130134
135+ /** Factory for McpServerSession takes session IDs */
131136 private McpServerSession .StreamableHttpSessionFactory streamableHttpSessionFactory ;
132137
133138 /**
@@ -242,6 +247,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
242247 return ;
243248 }
244249
250+ // Delayed until version negotiation is implemented.
251+ /*
252+ * if (session.getState().equals(session.STATE_INITIALIZED) &&
253+ * request.getHeader(PROTOCOL_VERSION_HEADER) == null) {
254+ * sendErrorResponse(response, "Protocol version missing in request header"); }
255+ */
256+
245257 // Set up SSE connection
246258 response .setContentType (TEXT_EVENT_STREAM );
247259 response .setCharacterEncoding (UTF_8 );
@@ -254,10 +266,18 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
254266
255267 String lastEventId = request .getHeader (LAST_EVENT_ID_HEADER );
256268
257- SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , lastEventId );
258- session .registerTransport (session .LISTENING_TRANSPORT , sseTransport );
259-
260- logger .debug ("Registered SSE transport {} for session {}" , session .LISTENING_TRANSPORT , sessionId );
269+ if (lastEventId == null ) { // Just opening a listening stream
270+ SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , lastEventId ,
271+ session .LISTENING_TRANSPORT , sessionId );
272+ session .registerTransport (session .LISTENING_TRANSPORT , sseTransport );
273+ logger .debug ("Registered SSE transport {} for session {}" , session .LISTENING_TRANSPORT , sessionId );
274+ }
275+ else { // Asking for a stream to replay events from a previous request
276+ SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , lastEventId ,
277+ request .getRequestId (), sessionId );
278+ session .registerTransport (request .getRequestId (), sseTransport );
279+ logger .debug ("Registered SSE transport {} for session {}" , session .LISTENING_TRANSPORT , sessionId );
280+ }
261281 }
262282
263283 @ Override
@@ -328,6 +348,15 @@ public void onAllDataRead() throws IOException {
328348 asyncContext .complete ();
329349 return ;
330350 }
351+
352+ // Delayed until version negotiation is implemented.
353+ /*
354+ * if (session.getState().equals(session.STATE_INITIALIZED) &&
355+ * request.getHeader(PROTOCOL_VERSION_HEADER) == null) {
356+ * sendErrorResponse(response,
357+ * "Protocol version missing in request header"); }
358+ */
359+
331360 logger .debug ("Using session: {}" , sessionId );
332361
333362 response .setHeader (SESSION_ID_HEADER , sessionId );
@@ -362,7 +391,8 @@ else if (id instanceof Integer) {
362391 response .setHeader (CACHE_CONTROL_HEADER , CACHE_CONTROL_NO_CACHE );
363392 response .setHeader (CONNECTION_HEADER , CONNECTION_KEEP_ALIVE );
364393
365- SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , null );
394+ SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , null ,
395+ transportId , sessionId );
366396 session .registerTransport (transportId , sseTransport );
367397 }
368398 else {
@@ -650,13 +680,17 @@ private static class SseTransport implements McpServerTransport {
650680
651681 private final Map <String , SseEvent > eventHistory = new ConcurrentHashMap <>();
652682
653- private final AtomicLong eventCounter = new AtomicLong (0 );
683+ private final String id ;
684+
685+ private final String sessionId ;
654686
655687 public SseTransport (ObjectMapper objectMapper , HttpServletResponse response , AsyncContext asyncContext ,
656- String lastEventId ) {
688+ String lastEventId , String transportId , String sessionId ) {
657689 this .objectMapper = objectMapper ;
658690 this .response = response ;
659691 this .asyncContext = asyncContext ;
692+ this .id = transportId ;
693+ this .sessionId = sessionId ;
660694
661695 setupSseStream (lastEventId );
662696 }
@@ -710,9 +744,19 @@ private void setupSseStream(String lastEventId) {
710744
711745 private void replayEventsAfter (String lastEventId ) {
712746 try {
713- long lastId = Long .parseLong (lastEventId );
714- for (long i = lastId + 1 ; i <= eventCounter .get (); i ++) {
715- SseEvent event = eventHistory .get (String .valueOf (i ));
747+ McpServerSession session = sessions .get (sessionId );
748+ String transportIdOfLastEventId = session .getTransportIdForEvent (lastEventId );
749+ Map <String , SseEvent > transportEventHistory = session
750+ .getTransportEventHistory (transportIdOfLastEventId );
751+ List <String > eventIds = transportEventHistory .keySet ()
752+ .stream ()
753+ .map (Long ::parseLong )
754+ .filter (key -> key > Long .parseLong (lastEventId ))
755+ .sorted ()
756+ .map (String ::valueOf )
757+ .collect (Collectors .toList ());
758+ for (String eventId : eventIds ) {
759+ SseEvent event = transportEventHistory .get (eventId );
716760 if (event != null ) {
717761 eventSink .tryEmitNext (event );
718762 }
@@ -727,7 +771,7 @@ private void replayEventsAfter(String lastEventId) {
727771 public Mono <Void > sendMessage (JSONRPCMessage message ) {
728772 try {
729773 String jsonText = objectMapper .writeValueAsString (message );
730- String eventId = String . valueOf ( eventCounter . incrementAndGet () );
774+ String eventId = sessions . get ( sessionId ). incrementAndGetEventId ( id );
731775 SseEvent event = new SseEvent (eventId , MESSAGE_EVENT_TYPE , jsonText );
732776
733777 eventHistory .put (eventId , event );
@@ -737,6 +781,7 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {
737781 if (message instanceof McpSchema .JSONRPCResponse ) {
738782 logger .debug ("Completing SSE stream after sending response" );
739783 eventSink .tryEmitComplete ();
784+ sessions .get (sessionId ).setTransportEventHistory (id , eventHistory );
740785 }
741786
742787 return Mono .empty ();
@@ -754,7 +799,7 @@ public Mono<Void> sendMessageStream(Flux<JSONRPCMessage> messageStream) {
754799 return messageStream .doOnNext (message -> {
755800 try {
756801 String jsonText = objectMapper .writeValueAsString (message );
757- String eventId = String . valueOf ( eventCounter . incrementAndGet () );
802+ String eventId = sessions . get ( sessionId ). incrementAndGetEventId ( id );
758803 SseEvent event = new SseEvent (eventId , MESSAGE_EVENT_TYPE , jsonText );
759804
760805 eventHistory .put (eventId , event );
@@ -768,6 +813,7 @@ public Mono<Void> sendMessageStream(Flux<JSONRPCMessage> messageStream) {
768813 }).doOnComplete (() -> {
769814 logger .debug ("Completing SSE stream after sending all stream messages" );
770815 eventSink .tryEmitComplete ();
816+ sessions .get (sessionId ).setTransportEventHistory (id , eventHistory );
771817 }).then ();
772818 }
773819
@@ -780,13 +826,11 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
780826 public Mono <Void > closeGracefully () {
781827 return Mono .fromRunnable (() -> {
782828 eventSink .tryEmitComplete ();
829+ sessions .get (sessionId ).setTransportEventHistory (id , eventHistory );
783830 logger .debug ("SSE transport closed gracefully" );
784831 });
785832 }
786833
787- private record SseEvent (String id , String event , String data ) {
788- }
789-
790834 }
791835
792836 /**
0 commit comments