Skip to content

Commit c8c21ba

Browse files
committed
feat(client): adds tests and code cleanup
1 parent c9bd70b commit c8c21ba

File tree

3 files changed

+186
-83
lines changed

3 files changed

+186
-83
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java

Lines changed: 98 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ public class StreamableHttpClientTransport implements McpClientTransport {
6060
private final AtomicBoolean fallbackToSse = new AtomicBoolean(false);
6161

6262
StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder,
63-
final ObjectMapper objectMapper, final String baseUri, final String endpoint) {
63+
final ObjectMapper objectMapper, final String baseUri, final String endpoint,
64+
final HttpClientSseClientTransport sseClientTransport) {
6465
this.httpClient = httpClient;
6566
this.requestBuilder = requestBuilder;
6667
this.objectMapper = objectMapper;
6768
this.uri = URI.create(baseUri + endpoint);
68-
this.sseClientTransport = HttpClientSseClientTransport.builder(baseUri).build();
69+
this.sseClientTransport = sseClientTransport;
6970
}
7071

7172
/**
@@ -98,21 +99,27 @@ public static class Builder {
9899
private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
99100
.header("Accept", "application/json, text/event-stream");
100101

101-
private ObjectMapper objectMapper;
102+
private ObjectMapper objectMapper = new ObjectMapper();
102103

103104
private String baseUri;
104105

105106
private String endpoint = "/mcp";
106107

108+
private Consumer<HttpClient.Builder> clientCustomizer;
109+
110+
private Consumer<HttpRequest.Builder> requestCustomizer;
111+
107112
public Builder withCustomizeClient(final Consumer<HttpClient.Builder> clientCustomizer) {
108113
Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
109114
clientCustomizer.accept(clientBuilder);
115+
this.clientCustomizer = clientCustomizer;
110116
return this;
111117
}
112118

113119
public Builder withCustomizeRequest(final Consumer<HttpRequest.Builder> requestCustomizer) {
114120
Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
115121
requestCustomizer.accept(requestBuilder);
122+
this.requestCustomizer = requestCustomizer;
116123
return this;
117124
}
118125

@@ -135,8 +142,22 @@ public Builder withEndpoint(final String endpoint) {
135142
}
136143

137144
public StreamableHttpClientTransport build() {
145+
final HttpClientSseClientTransport.Builder builder = HttpClientSseClientTransport.builder(baseUri)
146+
.objectMapper(objectMapper);
147+
if (clientCustomizer != null) {
148+
builder.customizeClient(clientCustomizer);
149+
}
150+
151+
if (requestCustomizer != null) {
152+
builder.customizeRequest(requestCustomizer);
153+
}
154+
155+
if (!endpoint.equals("/mcp")) {
156+
builder.sseEndpoint(endpoint);
157+
}
158+
138159
return new StreamableHttpClientTransport(clientBuilder.build(), requestBuilder, objectMapper, baseUri,
139-
endpoint);
160+
endpoint, builder.build());
140161
}
141162

142163
}
@@ -151,57 +172,44 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
151172
return Mono.error(new IllegalStateException("Already connected or connecting"));
152173
}
153174

154-
return sendInitialHandshake().then(Mono.defer(() -> Mono
155-
.fromFuture(() -> httpClient.sendAsync(requestBuilder.build(), HttpResponse.BodyHandlers.ofInputStream()))
156-
.flatMap(response -> handleStreamingResponse(handler, response))
175+
return Mono.defer(() -> Mono.fromFuture(() -> {
176+
final HttpRequest.Builder builder = requestBuilder.copy().GET().uri(uri);
177+
final String lastId = lastEventId.get();
178+
if (lastId != null) {
179+
builder.header("Last-Event-ID", lastId);
180+
}
181+
return httpClient.sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream());
182+
}).flatMap(response -> {
183+
if (response.statusCode() == 405 || response.statusCode() == 404) {
184+
LOGGER.warn("Operation not allowed, falling back to SSE");
185+
fallbackToSse.set(true);
186+
return sseClientTransport.connect(handler);
187+
}
188+
return handleStreamingResponse(handler, response);
189+
})
157190
.retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> err instanceof IllegalStateException))
158191
.doOnSuccess(v -> state.set(TransportState.CONNECTED))
159192
.doOnTerminate(() -> state.set(TransportState.CLOSED))
160193
.onErrorResume(e -> {
161-
state.set(TransportState.DISCONNECTED);
162-
LOGGER.error("Failed to connect", e);
194+
LOGGER.error("Streamable transport connection error", e);
163195
return Mono.error(e);
164-
}))).onErrorResume(e -> {
165-
if (e instanceof UnsupportedOperationException) {
166-
LOGGER.warn("Streamable transport failed, falling back to SSE.", e);
167-
fallbackToSse.set(true);
168-
return sseClientTransport.connect(handler);
169-
}
170-
return Mono.error(e);
171-
});
172-
196+
}));
173197
}
174198

175199
@Override
176200
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
177-
if (state.get() == TransportState.CLOSED) {
178-
return Mono.empty();
179-
}
180-
181201
if (fallbackToSse.get()) {
182202
return sseClientTransport.sendMessage(message);
183203
}
184204

185-
if (state.get() == TransportState.DISCONNECTED) {
186-
state.set(TransportState.CONNECTING);
187-
188-
return sendInitialHandshake().doOnSuccess(v -> state.set(TransportState.CONNECTED)).onErrorResume(e -> {
189-
if (e instanceof UnsupportedOperationException) {
190-
LOGGER.warn("Streamable transport failed, falling back to SSE.", e);
191-
fallbackToSse.set(true);
192-
return Mono.empty();
193-
}
194-
return Mono.error(e);
195-
}).then(sendMessage(message));
205+
if (state.get() == TransportState.CLOSED) {
206+
return Mono.empty();
196207
}
197208

198-
try {
199-
String json = objectMapper.writeValueAsString(message);
200-
return sentPost(json);
201-
}
202-
catch (Exception e) {
209+
return sentPost(message).onErrorResume(e -> {
210+
LOGGER.error("Streamable transport sendMessage error", e);
203211
return Mono.error(e);
204-
}
212+
});
205213
}
206214

207215
/**
@@ -210,71 +218,78 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
210218
* @return a Mono that completes when all messages have been sent
211219
*/
212220
public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages) {
213-
if (state.get() == TransportState.CLOSED) {
214-
return Mono.empty();
215-
}
216-
217221
if (fallbackToSse.get()) {
218222
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
219223
}
220224

221-
if (state.get() == TransportState.DISCONNECTED) {
222-
state.set(TransportState.CONNECTING);
223-
224-
return sendInitialHandshake().doOnSuccess(v -> state.set(TransportState.CONNECTED)).onErrorResume(e -> {
225-
if (e instanceof UnsupportedOperationException) {
226-
LOGGER.warn("Streamable transport failed, falling back to SSE.", e);
227-
fallbackToSse.set(true);
228-
return Mono.empty();
229-
}
230-
return Mono.error(e);
231-
}).then(sendMessages(messages));
225+
if (state.get() == TransportState.CLOSED) {
226+
return Mono.empty();
232227
}
233228

234-
try {
235-
String json = objectMapper.writeValueAsString(messages);
236-
return sentPost(json);
237-
}
238-
catch (Exception e) {
229+
return sentPost(messages).onErrorResume(e -> {
230+
LOGGER.error("Streamable transport sendMessages error", e);
239231
return Mono.error(e);
240-
}
232+
});
241233
}
242234

243-
private Mono<Void> sendInitialHandshake() {
244-
try {
245-
String json = objectMapper.writeValueAsString(new McpSchema.InitializeRequest("2025-03-26", null, null));
246-
HttpRequest req = requestBuilder.copy().uri(uri).POST(HttpRequest.BodyPublishers.ofString(json)).build();
247-
return Mono.fromFuture(httpClient.sendAsync(req, HttpResponse.BodyHandlers.discarding()))
235+
private Mono<Void> sentPost(final Object msg) {
236+
return serializeJson(msg).flatMap(json -> {
237+
final HttpRequest request = requestBuilder.copy()
238+
.POST(HttpRequest.BodyPublishers.ofString(json))
239+
.uri(uri)
240+
.build();
241+
return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()))
248242
.flatMap(response -> {
249-
int code = response.statusCode();
250-
if (code == 200) {
243+
244+
// If the response is 202 Accepted, there's no body to process
245+
if (response.statusCode() == 202) {
251246
return Mono.empty();
252247
}
253-
else if (code >= 400 && code < 500) {
254-
return Mono.error(new UnsupportedOperationException("Client error: " + code));
248+
249+
if (response.statusCode() == 405 || response.statusCode() == 404) {
250+
LOGGER.warn("Operation not allowed, falling back to SSE");
251+
fallbackToSse.set(true);
252+
if (msg instanceof McpSchema.JSONRPCMessage message) {
253+
return sseClientTransport.sendMessage(message);
254+
}
255+
256+
if (msg instanceof List<?> list) {
257+
@SuppressWarnings("unchecked")
258+
final List<McpSchema.JSONRPCMessage> messages = (List<McpSchema.JSONRPCMessage>) list;
259+
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
260+
}
255261
}
256-
else {
257-
return Mono.error(new IOException("Unexpected status code: " + code));
262+
263+
if (response.statusCode() >= 400) {
264+
return Mono
265+
.error(new IllegalArgumentException("Unexpected status code: " + response.statusCode()));
258266
}
259-
})
260-
.then();
267+
268+
return handleStreamingResponse(it -> it, response);
269+
});
270+
});
271+
272+
}
273+
274+
private Mono<String> serializeJson(final Object input) {
275+
try {
276+
if (input instanceof McpSchema.JSONRPCMessage || input instanceof List) {
277+
return Mono.just(objectMapper.writeValueAsString(input));
278+
}
279+
else {
280+
return Mono.error(new IllegalArgumentException("Unsupported message type for serialization"));
281+
}
261282
}
262283
catch (IOException e) {
284+
LOGGER.error("Error serializing JSON-RPC message", e);
263285
return Mono.error(e);
264286
}
265287
}
266288

267-
private Mono<Void> sentPost(String json) {
268-
HttpRequest request = requestBuilder.copy().POST(HttpRequest.BodyPublishers.ofString(json)).build();
269-
return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()))
270-
.flatMap(response -> handleStreamingResponse(msg -> msg, response))
271-
.then();
272-
}
273-
274289
private Mono<Void> handleStreamingResponse(
275290
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler,
276291
final HttpResponse<InputStream> response) {
277-
String contentType = response.headers().firstValue("Content-Type").orElse("");
292+
final String contentType = response.headers().firstValue("Content-Type").orElse("");
278293
if (contentType.contains("application/json-seq")) {
279294
return handleJsonStream(response, handler);
280295
}
@@ -292,7 +307,7 @@ else if (contentType.contains("application/json")) {
292307
private Mono<Void> handleSingleJson(final HttpResponse<InputStream> response,
293308
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
294309
return Mono.fromCallable(() -> {
295-
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
310+
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
296311
new String(response.body().readAllBytes(), StandardCharsets.UTF_8));
297312
return handler.apply(Mono.just(msg));
298313
}).flatMap(Function.identity()).then();
@@ -302,7 +317,7 @@ private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
302317
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
303318
return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()).flatMap(jsonLine -> {
304319
try {
305-
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine);
320+
final McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine);
306321
return handler.apply(Mono.just(message));
307322
}
308323
catch (IOException e) {
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package io.modelcontextprotocol.client;
2+
3+
import org.junit.jupiter.api.Timeout;
4+
import org.testcontainers.containers.GenericContainer;
5+
import org.testcontainers.containers.wait.strategy.Wait;
6+
7+
import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport;
8+
import io.modelcontextprotocol.spec.McpClientTransport;
9+
10+
/**
11+
* Tests for the {@link McpAsyncClient} with {@link StreamableHttpClientTransport}.
12+
*
13+
* @author Aliaksei Darafeyeu
14+
*/
15+
@Timeout(15)
16+
public class StreamableHttpClientTransportAsyncTest extends AbstractMcpAsyncClientTests {
17+
18+
String host = "http://localhost:3003";
19+
20+
// Uses the https://github.com/tzolov/mcp-everything-server-docker-image
21+
@SuppressWarnings("resource")
22+
GenericContainer<?> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1")
23+
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
24+
.withExposedPorts(3001)
25+
.waitingFor(Wait.forHttp("/").forStatusCode(404));
26+
27+
@Override
28+
protected McpClientTransport createMcpTransport() {
29+
return StreamableHttpClientTransport.builder(host).build();
30+
}
31+
32+
@Override
33+
protected void onStart() {
34+
container.start();
35+
int port = container.getMappedPort(3001);
36+
host = "http://" + container.getHost() + ":" + port;
37+
}
38+
39+
@Override
40+
protected void onClose() {
41+
container.stop();
42+
}
43+
44+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package io.modelcontextprotocol.client;
2+
3+
import org.junit.jupiter.api.Timeout;
4+
import org.testcontainers.containers.GenericContainer;
5+
import org.testcontainers.containers.wait.strategy.Wait;
6+
7+
import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport;
8+
import io.modelcontextprotocol.spec.McpClientTransport;
9+
10+
/**
11+
* Tests for the {@link McpSyncClient} with {@link StreamableHttpClientTransport}.
12+
*
13+
* @author Aliaksei Darafeyeu
14+
*/
15+
@Timeout(15)
16+
public class StreamableHttpClientTransportSyncTest extends AbstractMcpSyncClientTests {
17+
18+
String host = "http://localhost:3003";
19+
20+
// Uses the https://github.com/tzolov/mcp-everything-server-docker-image
21+
@SuppressWarnings("resource")
22+
GenericContainer<?> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1")
23+
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
24+
.withExposedPorts(3001)
25+
.waitingFor(Wait.forHttp("/").forStatusCode(404));
26+
27+
@Override
28+
protected McpClientTransport createMcpTransport() {
29+
return StreamableHttpClientTransport.builder(host).build();
30+
}
31+
32+
@Override
33+
protected void onStart() {
34+
container.start();
35+
int port = container.getMappedPort(3001);
36+
host = "http://" + container.getHost() + ":" + port;
37+
}
38+
39+
@Override
40+
protected void onClose() {
41+
container.stop();
42+
}
43+
44+
}

0 commit comments

Comments
 (0)