1414import java .net .http .HttpResponse ;
1515import java .nio .charset .StandardCharsets ;
1616import java .time .Duration ;
17+ import java .util .ArrayList ;
1718import java .util .List ;
1819import java .util .concurrent .atomic .AtomicBoolean ;
1920import java .util .concurrent .atomic .AtomicReference ;
@@ -192,6 +193,7 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
192193 .doOnTerminate (() -> state .set (TransportState .CLOSED ))
193194 .onErrorResume (e -> {
194195 LOGGER .error ("Streamable transport connection error" , e );
196+ state .set (TransportState .DISCONNECTED );
195197 return Mono .error (e );
196198 }));
197199 }
@@ -204,43 +206,14 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
204206 public Mono <Void > sendMessage (final McpSchema .JSONRPCMessage message ,
205207 final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
206208 if (fallbackToSse .get ()) {
207- return sseClientTransport . sendMessage (message );
209+ return fallbackToSse (message );
208210 }
209211
210212 if (state .get () == TransportState .CLOSED ) {
211213 return Mono .empty ();
212214 }
213215
214- return sentPost (message , handler ).onErrorResume (e -> {
215- LOGGER .error ("Streamable transport sendMessage error" , e );
216- return Mono .error (e );
217- });
218- }
219-
220- /**
221- * Sends a list of messages to the server.
222- * @param messages the list of messages to send
223- * @return a Mono that completes when all messages have been sent
224- */
225- public Mono <Void > sendMessages (final List <McpSchema .JSONRPCMessage > messages ,
226- final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
227- if (fallbackToSse .get ()) {
228- return Flux .fromIterable (messages ).flatMap (this ::sendMessage ).then ();
229- }
230-
231- if (state .get () == TransportState .CLOSED ) {
232- return Mono .empty ();
233- }
234-
235- return sentPost (messages , handler ).onErrorResume (e -> {
236- LOGGER .error ("Streamable transport sendMessages error" , e );
237- return Mono .error (e );
238- });
239- }
240-
241- private Mono <Void > sentPost (final Object msg ,
242- final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
243- return serializeJson (msg ).flatMap (json -> {
216+ return serializeJson (message ).flatMap (json -> {
244217 final HttpRequest request = requestBuilder .copy ()
245218 .POST (HttpRequest .BodyPublishers .ofString (json ))
246219 .uri (uri )
@@ -256,15 +229,7 @@ private Mono<Void> sentPost(final Object msg,
256229 if (response .statusCode () == 405 || response .statusCode () == 404 ) {
257230 LOGGER .warn ("Operation not allowed, falling back to SSE" );
258231 fallbackToSse .set (true );
259- if (msg instanceof McpSchema .JSONRPCMessage message ) {
260- return sseClientTransport .sendMessage (message );
261- }
262-
263- if (msg instanceof List <?> list ) {
264- @ SuppressWarnings ("unchecked" )
265- final List <McpSchema .JSONRPCMessage > messages = (List <McpSchema .JSONRPCMessage >) list ;
266- return Flux .fromIterable (messages ).flatMap (this ::sendMessage ).then ();
267- }
232+ return fallbackToSse (message );
268233 }
269234
270235 if (response .statusCode () >= 400 ) {
@@ -274,18 +239,28 @@ private Mono<Void> sentPost(final Object msg,
274239
275240 return handleStreamingResponse (response , handler );
276241 });
242+ }).onErrorResume (e -> {
243+ LOGGER .error ("Streamable transport sendMessages error" , e );
244+ return Mono .error (e );
277245 });
278246
279247 }
280248
281- private Mono <String > serializeJson (final Object input ) {
249+ private Mono <Void > fallbackToSse (final McpSchema .JSONRPCMessage msg ) {
250+ if (msg instanceof McpSchema .JSONRPCBatchRequest batch ) {
251+ return Flux .fromIterable (batch .items ()).flatMap (sseClientTransport ::sendMessage ).then ();
252+ }
253+
254+ if (msg instanceof McpSchema .JSONRPCBatchResponse batch ) {
255+ return Flux .fromIterable (batch .items ()).flatMap (sseClientTransport ::sendMessage ).then ();
256+ }
257+
258+ return sseClientTransport .sendMessage (msg );
259+ }
260+
261+ private Mono <String > serializeJson (final McpSchema .JSONRPCMessage msg ) {
282262 try {
283- if (input instanceof McpSchema .JSONRPCMessage || input instanceof List ) {
284- return Mono .just (objectMapper .writeValueAsString (input ));
285- }
286- else {
287- return Mono .error (new IllegalArgumentException ("Unsupported message type for serialization" ));
288- }
263+ return Mono .just (objectMapper .writeValueAsString (msg ));
289264 }
290265 catch (IOException e ) {
291266 LOGGER .error ("Error serializing JSON-RPC message" , e );
@@ -313,9 +288,15 @@ else if (contentType.contains("application/json")) {
313288 private Mono <Void > handleSingleJson (final HttpResponse <InputStream > response ,
314289 final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
315290 return Mono .fromCallable (() -> {
291+ try {
316292 final McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
317293 new String (response .body ().readAllBytes (), StandardCharsets .UTF_8 ));
318294 return handler .apply (Mono .just (msg ));
295+ }
296+ catch (IOException e ) {
297+ LOGGER .error ("Error processing JSON response" , e );
298+ return Mono .error (e );
299+ }
319300 }).flatMap (Function .identity ()).then ();
320301 }
321302
@@ -328,7 +309,7 @@ private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
328309 }
329310 catch (IOException e ) {
330311 LOGGER .error ("Error processing JSON line" , e );
331- return Mono .empty ( );
312+ return Mono .error ( e );
332313 }
333314 }).then ();
334315 }
@@ -347,7 +328,7 @@ private Mono<Void> handleSseStream(final HttpResponse<InputStream> response,
347328 if (line .startsWith ("event: " ))
348329 event = line .substring (7 ).trim ();
349330 else if (line .startsWith ("data: " ))
350- data += line .substring (6 ). trim () + "\n " ;
331+ data += line .substring (6 ) + "\n " ;
351332 else if (line .startsWith ("id: " ))
352333 id = line .substring (4 ).trim ();
353334 }
@@ -359,34 +340,35 @@ else if (line.startsWith("id: "))
359340 return new FlowSseClient .SseEvent (event , data , id );
360341 })
361342 .filter (sseEvent -> "message" .equals (sseEvent .type ()))
362- .doOnNext (sseEvent -> {
363- lastEventId . set ( sseEvent .id () );
343+ .concatMap (sseEvent -> {
344+ String rawData = sseEvent .data (). trim ( );
364345 try {
365- String rawData = sseEvent .data ().trim ();
366346 JsonNode node = objectMapper .readTree (rawData );
367-
347+ List < McpSchema . JSONRPCMessage > messages = new ArrayList <>();
368348 if (node .isArray ()) {
369349 for (JsonNode item : node ) {
370- String rawMessage = objectMapper .writeValueAsString (item );
371- McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
372- rawMessage );
373- handler .apply (Mono .just (msg )).subscribe ();
350+ messages .add (McpSchema .deserializeJsonRpcMessage (objectMapper , item .toString ()));
374351 }
375- }
376- else if (node .isObject ()) {
377- String rawMessage = objectMapper .writeValueAsString (node );
378- McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper , rawMessage );
379- handler .apply (Mono .just (msg )).subscribe ();
380- }
381- else {
352+ } else if (node .isObject ()) {
353+ messages .add (McpSchema .deserializeJsonRpcMessage (objectMapper , node .toString ()));
354+ } else {
355+ String warning = "Unexpected JSON in SSE data: " + rawData ;
382356 LOGGER .warn ("Unexpected JSON in SSE data: {}" , rawData );
357+ return Mono .error (new IllegalArgumentException (warning ));
383358 }
359+
360+ return Flux .fromIterable (messages )
361+ .concatMap (msg -> handler .apply (Mono .just (msg )))
362+ .then (Mono .fromRunnable (() -> {
363+ if (!sseEvent .id ().isEmpty ()) {
364+ lastEventId .set (sseEvent .id ());
365+ }
366+ }));
367+ } catch (IOException e ) {
368+ LOGGER .error ("Error parsing SSE JSON: {}" , rawData , e );
369+ return Mono .error (e );
384370 }
385- catch (IOException e ) {
386- LOGGER .error ("Error processing SSE event: {}" , sseEvent .data (), e );
387- }
388- })
389- .then ();
371+ }).then ();
390372 }
391373
392374 @ Override
0 commit comments