Skip to content

Commit 4ecf709

Browse files
committed
feat(client): adds StreamableHttpClientTransport
1 parent 41c6bd9 commit 4ecf709

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.client.transport;
6+
7+
import java.io.BufferedReader;
8+
import java.io.IOException;
9+
import java.io.InputStream;
10+
import java.io.InputStreamReader;
11+
import java.net.URI;
12+
import java.net.http.HttpClient;
13+
import java.net.http.HttpRequest;
14+
import java.net.http.HttpResponse;
15+
import java.nio.charset.StandardCharsets;
16+
import java.time.Duration;
17+
import java.util.concurrent.atomic.AtomicBoolean;
18+
import java.util.concurrent.atomic.AtomicReference;
19+
import java.util.function.Consumer;
20+
import java.util.function.Function;
21+
22+
import org.slf4j.Logger;
23+
import org.slf4j.LoggerFactory;
24+
25+
import com.fasterxml.jackson.core.type.TypeReference;
26+
import com.fasterxml.jackson.databind.ObjectMapper;
27+
28+
import io.modelcontextprotocol.spec.McpClientTransport;
29+
import io.modelcontextprotocol.spec.McpSchema;
30+
import io.modelcontextprotocol.util.Assert;
31+
import reactor.core.publisher.Flux;
32+
import reactor.core.publisher.Mono;
33+
import reactor.util.retry.Retry;
34+
35+
/**
36+
* A transport implementation for the Model Context Protocol (MCP) using JSON streaming.
37+
*
38+
* @author Aliaksei Darafeyeu
39+
*/
40+
public class StreamableHttpClientTransport implements McpClientTransport {
41+
42+
private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpClientTransport.class);
43+
44+
private final HttpClientSseClientTransport sseClientTransport;
45+
46+
private final HttpClient httpClient;
47+
48+
private final HttpRequest.Builder requestBuilder;
49+
50+
private final ObjectMapper objectMapper;
51+
52+
private final URI uri;
53+
54+
private final AtomicReference<TransportState> state = new AtomicReference<>(TransportState.DISCONNECTED);
55+
56+
private final AtomicReference<String> lastEventId = new AtomicReference<>();
57+
58+
private final AtomicBoolean fallbackToSse = new AtomicBoolean(false);
59+
60+
StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder,
61+
final ObjectMapper objectMapper, final String baseUri, final String endpoint) {
62+
this.httpClient = httpClient;
63+
this.requestBuilder = requestBuilder;
64+
this.objectMapper = objectMapper;
65+
this.uri = URI.create(baseUri + endpoint);
66+
this.sseClientTransport = HttpClientSseClientTransport.builder(baseUri).build();
67+
}
68+
69+
/**
70+
* Creates a new StreamableHttpClientTransport instance with the specified URI.
71+
* @param uri the URI to connect to
72+
* @return a new Builder instance
73+
*/
74+
public static Builder builder(final String uri) {
75+
return new Builder().withBaseUri(uri);
76+
}
77+
78+
/**
79+
* The state of the Transport connection.
80+
*/
81+
public enum TransportState {
82+
83+
DISCONNECTED, CONNECTING, CONNECTED, CLOSED
84+
85+
}
86+
87+
/**
88+
* A builder for creating instances of WebSocketClientTransport.
89+
*/
90+
public static class Builder {
91+
92+
private final HttpClient.Builder clientBuilder = HttpClient.newBuilder()
93+
.version(HttpClient.Version.HTTP_1_1)
94+
.connectTimeout(Duration.ofSeconds(10));
95+
96+
private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
97+
.header("Accept", "application/json, text/event-stream");
98+
99+
private ObjectMapper objectMapper;
100+
101+
private String baseUri;
102+
103+
private String endpoint = "/mcp";
104+
105+
public Builder withCustomizeClient(final Consumer<HttpClient.Builder> clientCustomizer) {
106+
Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
107+
clientCustomizer.accept(clientBuilder);
108+
return this;
109+
}
110+
111+
public Builder withCustomizeRequest(final Consumer<HttpRequest.Builder> requestCustomizer) {
112+
Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
113+
requestCustomizer.accept(requestBuilder);
114+
return this;
115+
}
116+
117+
public Builder withObjectMapper(final ObjectMapper objectMapper) {
118+
Assert.notNull(objectMapper, "objectMapper must not be null");
119+
this.objectMapper = objectMapper;
120+
return this;
121+
}
122+
123+
public Builder withBaseUri(final String baseUri) {
124+
Assert.hasText(baseUri, "baseUri must not be empty");
125+
this.baseUri = baseUri;
126+
return this;
127+
}
128+
129+
public Builder withEndpoint(final String endpoint) {
130+
Assert.hasText(endpoint, "endpoint must not be empty");
131+
this.endpoint = endpoint;
132+
return this;
133+
}
134+
135+
public StreamableHttpClientTransport build() {
136+
return new StreamableHttpClientTransport(clientBuilder.build(), requestBuilder, objectMapper, baseUri,
137+
endpoint);
138+
}
139+
140+
}
141+
142+
@Override
143+
public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
144+
if (fallbackToSse.get()) {
145+
return sseClientTransport.connect(handler);
146+
}
147+
148+
if (!state.compareAndSet(TransportState.DISCONNECTED, TransportState.CONNECTING)) {
149+
return Mono.error(new IllegalStateException("Already connected or connecting"));
150+
}
151+
152+
return sendInitialHandshake().then(Mono.defer(() -> Mono
153+
.fromFuture(() -> httpClient.sendAsync(requestBuilder.build(), HttpResponse.BodyHandlers.ofInputStream()))
154+
.flatMap(response -> handleStreamingResponse(handler, response))
155+
.retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> err instanceof IllegalStateException))
156+
.doOnSuccess(v -> state.set(TransportState.CONNECTED))
157+
.doOnTerminate(() -> state.set(TransportState.CLOSED))
158+
.onErrorResume(e -> {
159+
state.set(TransportState.DISCONNECTED);
160+
LOGGER.error("Failed to connect", e);
161+
return Mono.error(e);
162+
}))).onErrorResume(e -> {
163+
if (e instanceof UnsupportedOperationException) {
164+
LOGGER.warn("Streamable transport failed, falling back to SSE.", e);
165+
fallbackToSse.set(true);
166+
return sseClientTransport.connect(handler);
167+
}
168+
return Mono.error(e);
169+
});
170+
171+
}
172+
173+
@Override
174+
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
175+
if (state.get() == TransportState.CLOSED) {
176+
return Mono.empty();
177+
}
178+
179+
if (fallbackToSse.get()) {
180+
return sseClientTransport.sendMessage(message);
181+
}
182+
183+
if (state.get() == TransportState.DISCONNECTED) {
184+
state.set(TransportState.CONNECTING);
185+
186+
return sendInitialHandshake().doOnSuccess(v -> state.set(TransportState.CONNECTED)).onErrorResume(e -> {
187+
if (e instanceof UnsupportedOperationException) {
188+
LOGGER.warn("Streamable transport failed, falling back to SSE.", e);
189+
fallbackToSse.set(true);
190+
return Mono.empty();
191+
}
192+
return Mono.error(e);
193+
}).then(sendMessage(message));
194+
}
195+
196+
try {
197+
String json = objectMapper.writeValueAsString(message);
198+
HttpRequest request = requestBuilder.copy().POST(HttpRequest.BodyPublishers.ofString(json)).build();
199+
return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()))
200+
.flatMap(response -> handleStreamingResponse(msg -> msg, response))
201+
.then();
202+
}
203+
catch (Exception e) {
204+
return Mono.error(e);
205+
}
206+
}
207+
208+
private Mono<Void> sendInitialHandshake() {
209+
try {
210+
String json = objectMapper.writeValueAsString(new McpSchema.InitializeRequest("2025-03-26", null, null));
211+
HttpRequest req = requestBuilder.copy().uri(uri).POST(HttpRequest.BodyPublishers.ofString(json)).build();
212+
return Mono.fromFuture(httpClient.sendAsync(req, HttpResponse.BodyHandlers.discarding()))
213+
.flatMap(response -> {
214+
int code = response.statusCode();
215+
if (code == 200) {
216+
return Mono.empty();
217+
}
218+
else if (code >= 400 && code < 500) {
219+
return Mono.error(new UnsupportedOperationException("Client error: " + code));
220+
}
221+
else {
222+
return Mono.error(new IOException("Unexpected status code: " + code));
223+
}
224+
})
225+
.then();
226+
}
227+
catch (IOException e) {
228+
return Mono.error(e);
229+
}
230+
}
231+
232+
private Mono<Void> handleStreamingResponse(
233+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler,
234+
final HttpResponse<InputStream> response) {
235+
String contentType = response.headers().firstValue("Content-Type").orElse("");
236+
if (contentType.contains("application/json-seq")) {
237+
return handleJsonStream(response, handler);
238+
}
239+
else if (contentType.contains("text/event-stream")) {
240+
return handleSseStream(response, handler);
241+
}
242+
else if (contentType.contains("application/json")) {
243+
return handleSingleJson(response, handler);
244+
}
245+
else {
246+
return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType));
247+
}
248+
}
249+
250+
private Mono<Void> handleSingleJson(final HttpResponse<InputStream> response,
251+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
252+
return Mono.fromCallable(() -> {
253+
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
254+
new String(response.body().readAllBytes(), StandardCharsets.UTF_8));
255+
return handler.apply(Mono.just(msg));
256+
}).flatMap(Function.identity()).then();
257+
}
258+
259+
private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
260+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
261+
return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()).flatMap(jsonLine -> {
262+
try {
263+
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine);
264+
return handler.apply(Mono.just(message));
265+
}
266+
catch (IOException e) {
267+
LOGGER.error("Error processing JSON line", e);
268+
return Mono.empty();
269+
}
270+
}).then();
271+
}
272+
273+
private Mono<Void> handleSseStream(final HttpResponse<InputStream> response,
274+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
275+
return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines())
276+
.scan(new FlowSseClient.SseEvent("", "", ""), (acc, line) -> {
277+
String event = acc.type();
278+
String data = acc.data();
279+
String id = acc.id();
280+
281+
if (line.startsWith("event: "))
282+
event = line.substring(7).trim();
283+
else if (line.startsWith("data: "))
284+
data = line.substring(6).trim();
285+
else if (line.startsWith("id: "))
286+
id = line.substring(4).trim();
287+
288+
return new FlowSseClient.SseEvent(event, data, id);
289+
})
290+
.filter(sseEvent -> "message".equals(sseEvent.type()))
291+
.doOnNext(sseEvent -> {
292+
lastEventId.set(sseEvent.id());
293+
try {
294+
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, sseEvent.data());
295+
handler.apply(Mono.just(msg)).subscribe();
296+
}
297+
catch (IOException e) {
298+
LOGGER.error("Error processing SSE event", e);
299+
}
300+
})
301+
.then();
302+
}
303+
304+
@Override
305+
public Mono<Void> closeGracefully() {
306+
state.set(TransportState.CLOSED);
307+
return Mono.empty();
308+
}
309+
310+
@Override
311+
public <T> T unmarshalFrom(final Object data, final TypeReference<T> typeRef) {
312+
return objectMapper.convertValue(data, typeRef);
313+
}
314+
315+
public TransportState getState() {
316+
return state.get();
317+
}
318+
319+
}

0 commit comments

Comments
 (0)