|
14 | 14 | import java.net.http.HttpResponse; |
15 | 15 | import java.nio.charset.StandardCharsets; |
16 | 16 | import java.time.Duration; |
| 17 | +import java.util.List; |
17 | 18 | import java.util.concurrent.atomic.AtomicBoolean; |
18 | 19 | import java.util.concurrent.atomic.AtomicReference; |
19 | 20 | import java.util.function.Consumer; |
@@ -195,10 +196,43 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) { |
195 | 196 |
|
196 | 197 | try { |
197 | 198 | String json = objectMapper.writeValueAsString(message); |
198 | | - HttpRequest request = requestBuilder.copy().POST(HttpRequest.BodyPublishers.ofString(json)).build(); |
199 | | - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream())) |
200 | | - .flatMap(response -> handleStreamingResponse(msg -> msg, response)) |
201 | | - .then(); |
| 199 | + return sentPost(json); |
| 200 | + } |
| 201 | + catch (Exception e) { |
| 202 | + return Mono.error(e); |
| 203 | + } |
| 204 | + } |
| 205 | + |
| 206 | + /** |
| 207 | + * Sends a list of messages to the server. |
| 208 | + * @param messages the list of messages to send |
| 209 | + * @return a Mono that completes when all messages have been sent |
| 210 | + */ |
| 211 | + public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages) { |
| 212 | + if (state.get() == TransportState.CLOSED) { |
| 213 | + return Mono.empty(); |
| 214 | + } |
| 215 | + |
| 216 | + if (fallbackToSse.get()) { |
| 217 | + return Flux.fromIterable(messages).flatMap(this::sendMessage).then(); |
| 218 | + } |
| 219 | + |
| 220 | + if (state.get() == TransportState.DISCONNECTED) { |
| 221 | + state.set(TransportState.CONNECTING); |
| 222 | + |
| 223 | + return sendInitialHandshake().doOnSuccess(v -> state.set(TransportState.CONNECTED)).onErrorResume(e -> { |
| 224 | + if (e instanceof UnsupportedOperationException) { |
| 225 | + LOGGER.warn("Streamable transport failed, falling back to SSE.", e); |
| 226 | + fallbackToSse.set(true); |
| 227 | + return Mono.empty(); |
| 228 | + } |
| 229 | + return Mono.error(e); |
| 230 | + }).then(sendMessages(messages)); |
| 231 | + } |
| 232 | + |
| 233 | + try { |
| 234 | + String json = objectMapper.writeValueAsString(messages); |
| 235 | + return sentPost(json); |
202 | 236 | } |
203 | 237 | catch (Exception e) { |
204 | 238 | return Mono.error(e); |
@@ -229,6 +263,13 @@ else if (code >= 400 && code < 500) { |
229 | 263 | } |
230 | 264 | } |
231 | 265 |
|
| 266 | + private Mono<Void> sentPost(String json) { |
| 267 | + HttpRequest request = requestBuilder.copy().POST(HttpRequest.BodyPublishers.ofString(json)).build(); |
| 268 | + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream())) |
| 269 | + .flatMap(response -> handleStreamingResponse(msg -> msg, response)) |
| 270 | + .then(); |
| 271 | + } |
| 272 | + |
232 | 273 | private Mono<Void> handleStreamingResponse( |
233 | 274 | final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler, |
234 | 275 | final HttpResponse<InputStream> response) { |
|
0 commit comments