1414import java .net .http .HttpResponse ;
1515import java .nio .charset .StandardCharsets ;
1616import java .time .Duration ;
17+ import java .util .ArrayList ;
1718import java .util .List ;
1819import java .util .concurrent .atomic .AtomicBoolean ;
1920import java .util .concurrent .atomic .AtomicReference ;
@@ -43,6 +44,24 @@ public class StreamableHttpClientTransport implements McpClientTransport {
4344
4445 private static final Logger LOGGER = LoggerFactory .getLogger (StreamableHttpClientTransport .class );
4546
47+ private static final String DEFAULT_MCP_ENDPOINT = "/mcp" ;
48+
49+ private static final String MCP_SESSION_ID = "Mcp-Session-Id" ;
50+
51+ private static final String LAST_EVENT_ID = "Last-Event-ID" ;
52+
53+ private static final String ACCEPT = "Accept" ;
54+
55+ private static final String CONTENT_TYPE = "Content-Type" ;
56+
57+ private static final String APPLICATION_JSON = "application/json" ;
58+
59+ private static final String TEXT_EVENT_STREAM = "text/event-stream" ;
60+
61+ private static final String APPLICATION_JSON_SEQ = "application/json-seq" ;
62+
63+ private static final String DEFAULT_ACCEPT_VALUES = "%s, %s" .formatted (APPLICATION_JSON , TEXT_EVENT_STREAM );
64+
4665 private final HttpClientSseClientTransport sseClientTransport ;
4766
4867 private final HttpClient httpClient ;
@@ -57,6 +76,8 @@ public class StreamableHttpClientTransport implements McpClientTransport {
5776
5877 private final AtomicReference <String > lastEventId = new AtomicReference <>();
5978
79+ private final AtomicReference <String > mcpSessionId = new AtomicReference <>();
80+
6081 private final AtomicBoolean fallbackToSse = new AtomicBoolean (false );
6182
6283 StreamableHttpClientTransport (final HttpClient httpClient , final HttpRequest .Builder requestBuilder ,
@@ -96,14 +117,13 @@ public static class Builder {
96117 .version (HttpClient .Version .HTTP_1_1 )
97118 .connectTimeout (Duration .ofSeconds (10 ));
98119
99- private final HttpRequest .Builder requestBuilder = HttpRequest .newBuilder ()
100- .header ("Accept" , "application/json, text/event-stream" );
120+ private final HttpRequest .Builder requestBuilder = HttpRequest .newBuilder ();
101121
102122 private ObjectMapper objectMapper = new ObjectMapper ();
103123
104124 private String baseUri ;
105125
106- private String endpoint = "/mcp" ;
126+ private String endpoint = DEFAULT_MCP_ENDPOINT ;
107127
108128 private Consumer <HttpClient .Builder > clientCustomizer ;
109129
@@ -152,7 +172,7 @@ public StreamableHttpClientTransport build() {
152172 builder .customizeRequest (requestCustomizer );
153173 }
154174
155- if (!endpoint .equals ("/mcp" )) {
175+ if (!endpoint .equals (DEFAULT_MCP_ENDPOINT )) {
156176 builder .sseEndpoint (endpoint );
157177 }
158178
@@ -173,13 +193,24 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
173193 }
174194
175195 return Mono .defer (() -> Mono .fromFuture (() -> {
176- final HttpRequest .Builder builder = requestBuilder .copy ().GET ().uri (uri );
196+ final HttpRequest .Builder request = requestBuilder .copy ().GET (). header ( ACCEPT , TEXT_EVENT_STREAM ).uri (uri );
177197 final String lastId = lastEventId .get ();
178198 if (lastId != null ) {
179- builder .header ("Last-Event-ID" , lastId );
199+ request .header (LAST_EVENT_ID , lastId );
180200 }
181- return httpClient .sendAsync (builder .build (), HttpResponse .BodyHandlers .ofInputStream ());
201+ if (mcpSessionId .get () != null ) {
202+ request .header (MCP_SESSION_ID , mcpSessionId .get ());
203+ }
204+
205+ return httpClient .sendAsync (request .build (), HttpResponse .BodyHandlers .ofInputStream ());
182206 }).flatMap (response -> {
207+ // must like server terminate session and the client need to start a
208+ // new session by sending a new `InitializeRequest` without a session
209+ // ID attached.
210+ if (mcpSessionId .get () != null && response .statusCode () == 404 ) {
211+ mcpSessionId .set (null );
212+ }
213+
183214 if (response .statusCode () == 405 || response .statusCode () == 404 ) {
184215 LOGGER .warn ("Operation not allowed, falling back to SSE" );
185216 fallbackToSse .set (true );
@@ -192,6 +223,7 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
192223 .doOnTerminate (() -> state .set (TransportState .CLOSED ))
193224 .onErrorResume (e -> {
194225 LOGGER .error ("Streamable transport connection error" , e );
226+ state .set (TransportState .DISCONNECTED );
195227 return Mono .error (e );
196228 }));
197229 }
@@ -204,67 +236,52 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
204236 public Mono <Void > sendMessage (final McpSchema .JSONRPCMessage message ,
205237 final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
206238 if (fallbackToSse .get ()) {
207- return sseClientTransport . sendMessage (message );
239+ return fallbackToSse (message );
208240 }
209241
210242 if (state .get () == TransportState .CLOSED ) {
211243 return Mono .empty ();
212244 }
213245
214- return sentPost (message , handler ).onErrorResume (e -> {
215- LOGGER .error ("Streamable transport sendMessage error" , e );
216- return Mono .error (e );
217- });
218- }
219-
220- /**
221- * Sends a list of messages to the server.
222- * @param messages the list of messages to send
223- * @return a Mono that completes when all messages have been sent
224- */
225- public Mono <Void > sendMessages (final List <McpSchema .JSONRPCMessage > messages ,
226- final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
227- if (fallbackToSse .get ()) {
228- return Flux .fromIterable (messages ).flatMap (this ::sendMessage ).then ();
229- }
230-
231- if (state .get () == TransportState .CLOSED ) {
232- return Mono .empty ();
233- }
234-
235- return sentPost (messages , handler ).onErrorResume (e -> {
236- LOGGER .error ("Streamable transport sendMessages error" , e );
237- return Mono .error (e );
238- });
239- }
240-
241- private Mono <Void > sentPost (final Object msg ,
242- final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
243- return serializeJson (msg ).flatMap (json -> {
244- final HttpRequest request = requestBuilder .copy ()
246+ return serializeJson (message ).flatMap (json -> {
247+ final HttpRequest .Builder request = requestBuilder .copy ()
245248 .POST (HttpRequest .BodyPublishers .ofString (json ))
246- .uri (uri )
247- .build ();
248- return Mono .fromFuture (httpClient .sendAsync (request , HttpResponse .BodyHandlers .ofInputStream ()))
249+ .header (ACCEPT , DEFAULT_ACCEPT_VALUES )
250+ .header (CONTENT_TYPE , APPLICATION_JSON )
251+ .uri (uri );
252+ if (mcpSessionId .get () != null ) {
253+ request .header (MCP_SESSION_ID , mcpSessionId .get ());
254+ }
255+
256+ return Mono .fromFuture (httpClient .sendAsync (request .build (), HttpResponse .BodyHandlers .ofInputStream ()))
249257 .flatMap (response -> {
250258
259+ // server may assign a session ID at initialization time, if yes we
260+ // have to use it for any subsequent requests
261+ if (message instanceof McpSchema .JSONRPCRequest
262+ && ((McpSchema .JSONRPCRequest ) message ).method ().equals (McpSchema .METHOD_INITIALIZE )) {
263+ response .headers ()
264+ .firstValue (MCP_SESSION_ID )
265+ .map (String ::trim )
266+ .ifPresent (this .mcpSessionId ::set );
267+ }
268+
251269 // If the response is 202 Accepted, there's no body to process
252270 if (response .statusCode () == 202 ) {
253271 return Mono .empty ();
254272 }
255273
274+ // must like server terminate session and the client need to start a
275+ // new session by sending a new `InitializeRequest` without a session
276+ // ID attached.
277+ if (mcpSessionId .get () != null && response .statusCode () == 404 ) {
278+ mcpSessionId .set (null );
279+ }
280+
256281 if (response .statusCode () == 405 || response .statusCode () == 404 ) {
257282 LOGGER .warn ("Operation not allowed, falling back to SSE" );
258283 fallbackToSse .set (true );
259- if (msg instanceof McpSchema .JSONRPCMessage message ) {
260- return sseClientTransport .sendMessage (message );
261- }
262-
263- if (msg instanceof List <?> list ) {
264- @ SuppressWarnings ("unchecked" )
265- final List <McpSchema .JSONRPCMessage > messages = (List <McpSchema .JSONRPCMessage >) list ;
266- return Flux .fromIterable (messages ).flatMap (this ::sendMessage ).then ();
267- }
284+ return fallbackToSse (message );
268285 }
269286
270287 if (response .statusCode () >= 400 ) {
@@ -274,18 +291,28 @@ private Mono<Void> sentPost(final Object msg,
274291
275292 return handleStreamingResponse (response , handler );
276293 });
294+ }).onErrorResume (e -> {
295+ LOGGER .error ("Streamable transport sendMessages error" , e );
296+ return Mono .error (e );
277297 });
278298
279299 }
280300
281- private Mono <String > serializeJson (final Object input ) {
301+ private Mono <Void > fallbackToSse (final McpSchema .JSONRPCMessage msg ) {
302+ if (msg instanceof McpSchema .JSONRPCBatchRequest batch ) {
303+ return Flux .fromIterable (batch .items ()).flatMap (sseClientTransport ::sendMessage ).then ();
304+ }
305+
306+ if (msg instanceof McpSchema .JSONRPCBatchResponse batch ) {
307+ return Flux .fromIterable (batch .items ()).flatMap (sseClientTransport ::sendMessage ).then ();
308+ }
309+
310+ return sseClientTransport .sendMessage (msg );
311+ }
312+
313+ private Mono <String > serializeJson (final McpSchema .JSONRPCMessage msg ) {
282314 try {
283- if (input instanceof McpSchema .JSONRPCMessage || input instanceof List ) {
284- return Mono .just (objectMapper .writeValueAsString (input ));
285- }
286- else {
287- return Mono .error (new IllegalArgumentException ("Unsupported message type for serialization" ));
288- }
315+ return Mono .just (objectMapper .writeValueAsString (msg ));
289316 }
290317 catch (IOException e ) {
291318 LOGGER .error ("Error serializing JSON-RPC message" , e );
@@ -295,27 +322,31 @@ private Mono<String> serializeJson(final Object input) {
295322
296323 private Mono <Void > handleStreamingResponse (final HttpResponse <InputStream > response ,
297324 final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
298- final String contentType = response .headers ().firstValue ("Content-Type" ).orElse ("" );
299- if (contentType .contains ("application/json-seq" )) {
325+ final String contentType = response .headers ().firstValue (CONTENT_TYPE ).orElse ("" );
326+ if (contentType .contains (APPLICATION_JSON_SEQ )) {
300327 return handleJsonStream (response , handler );
301328 }
302- else if (contentType .contains ("text/event-stream" )) {
329+ else if (contentType .contains (TEXT_EVENT_STREAM )) {
303330 return handleSseStream (response , handler );
304331 }
305- else if (contentType .contains ("application/json" )) {
332+ else if (contentType .contains (APPLICATION_JSON )) {
306333 return handleSingleJson (response , handler );
307334 }
308- else {
309- return Mono .error (new UnsupportedOperationException ("Unsupported Content-Type: " + contentType ));
310- }
335+ return Mono .error (new UnsupportedOperationException ("Unsupported Content-Type: " + contentType ));
311336 }
312337
313338 private Mono <Void > handleSingleJson (final HttpResponse <InputStream > response ,
314339 final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
315340 return Mono .fromCallable (() -> {
316- final McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
317- new String (response .body ().readAllBytes (), StandardCharsets .UTF_8 ));
318- return handler .apply (Mono .just (msg ));
341+ try {
342+ final McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
343+ new String (response .body ().readAllBytes (), StandardCharsets .UTF_8 ));
344+ return handler .apply (Mono .just (msg ));
345+ }
346+ catch (IOException e ) {
347+ LOGGER .error ("Error processing JSON response" , e );
348+ return Mono .error (e );
349+ }
319350 }).flatMap (Function .identity ()).then ();
320351 }
321352
@@ -328,7 +359,7 @@ private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
328359 }
329360 catch (IOException e ) {
330361 LOGGER .error ("Error processing JSON line" , e );
331- return Mono .empty ( );
362+ return Mono .error ( e );
332363 }
333364 }).then ();
334365 }
@@ -347,7 +378,7 @@ private Mono<Void> handleSseStream(final HttpResponse<InputStream> response,
347378 if (line .startsWith ("event: " ))
348379 event = line .substring (7 ).trim ();
349380 else if (line .startsWith ("data: " ))
350- data += line .substring (6 ). trim () + "\n " ;
381+ data += line .substring (6 ) + "\n " ;
351382 else if (line .startsWith ("id: " ))
352383 id = line .substring (4 ).trim ();
353384 }
@@ -356,34 +387,39 @@ else if (line.startsWith("id: "))
356387 data = data .substring (0 , data .length () - 1 );
357388 }
358389
359- return new FlowSseClient .SseEvent (event , data , id );
390+ return new FlowSseClient .SseEvent (id , event , data );
360391 })
361392 .filter (sseEvent -> "message" .equals (sseEvent .type ()))
362- .doOnNext (sseEvent -> {
363- lastEventId . set ( sseEvent .id () );
393+ .concatMap (sseEvent -> {
394+ String rawData = sseEvent .data (). trim ( );
364395 try {
365- String rawData = sseEvent .data ().trim ();
366396 JsonNode node = objectMapper .readTree (rawData );
367-
397+ List < McpSchema . JSONRPCMessage > messages = new ArrayList <>();
368398 if (node .isArray ()) {
369399 for (JsonNode item : node ) {
370- String rawMessage = objectMapper .writeValueAsString (item );
371- McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
372- rawMessage );
373- handler .apply (Mono .just (msg )).subscribe ();
400+ messages .add (McpSchema .deserializeJsonRpcMessage (objectMapper , item .toString ()));
374401 }
375402 }
376403 else if (node .isObject ()) {
377- String rawMessage = objectMapper .writeValueAsString (node );
378- McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper , rawMessage );
379- handler .apply (Mono .just (msg )).subscribe ();
404+ messages .add (McpSchema .deserializeJsonRpcMessage (objectMapper , node .toString ()));
380405 }
381406 else {
382- LOGGER .warn ("Unexpected JSON in SSE data: {}" , rawData );
407+ String warning = "Unexpected JSON in SSE data: " + rawData ;
408+ LOGGER .warn (warning );
409+ return Mono .error (new IllegalArgumentException (warning ));
383410 }
411+
412+ return Flux .fromIterable (messages )
413+ .concatMap (msg -> handler .apply (Mono .just (msg )))
414+ .then (Mono .fromRunnable (() -> {
415+ if (!sseEvent .id ().isEmpty ()) {
416+ lastEventId .set (sseEvent .id ());
417+ }
418+ }));
384419 }
385420 catch (IOException e ) {
386- LOGGER .error ("Error processing SSE event: {}" , sseEvent .data (), e );
421+ LOGGER .error ("Error parsing SSE JSON: {}" , rawData , e );
422+ return Mono .error (e );
387423 }
388424 })
389425 .then ();
0 commit comments