Skip to content

Commit 7b7fa87

Browse files
committed
More logs, resilience tests improved
1 parent 84f7ae0 commit 7b7fa87

File tree

7 files changed

+154
-97
lines changed

7 files changed

+154
-97
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 103 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,31 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
6969
return Mono.deferContextual(ctx -> {
7070
this.handler.set(handler);
7171
if (openConnectionOnStartup) {
72+
logger.debug("Eagerly opening connection on startup");
7273
this.reconnect(null, ctx);
7374
}
7475
return Mono.empty();
7576
});
7677
}
7778

7879
@Override
79-
public void handleException(Consumer<Throwable> handler) {
80+
public void registerExceptionHandler(Consumer<Throwable> handler) {
81+
logger.debug("Exception handler registered");
8082
this.exceptionHandler.set(handler);
8183
}
8284

85+
private void handleException(Throwable t) {
86+
logger.debug("Handling exception for session {}", activeSession.get().sessionId(), t);
87+
Consumer<Throwable> handler = this.exceptionHandler.get();
88+
if (handler != null) {
89+
handler.accept(t);
90+
}
91+
}
92+
8393
@Override
8494
public Mono<Void> closeGracefully() {
8595
return Mono.defer(() -> {
96+
logger.debug("Graceful close triggered");
8697
McpTransportSession currentSession = this.activeSession.get();
8798
if (currentSession != null) {
8899
return currentSession.closeGracefully();
@@ -92,6 +103,12 @@ public Mono<Void> closeGracefully() {
92103
}
93104

94105
private void reconnect(McpStream stream, ContextView ctx) {
106+
if (stream != null) {
107+
logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId());
108+
}
109+
else {
110+
logger.debug("Reconnecting with no prior stream");
111+
}
95112
// Here we attempt to initialize the client.
96113
// In case the server supports SSE, we will establish a long-running session
97114
// here and
@@ -113,10 +130,11 @@ private void reconnect(McpStream stream, ContextView ctx) {
113130
.exchangeToFlux(response -> {
114131
// Per spec, we are not checking whether it's 2xx, but only if the
115132
// Accept header is proper.
116-
if (response.headers().contentType().isPresent()
133+
if (response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent()
117134
&& response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
118135

119136
McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams);
137+
logger.debug("Established stream {}", sessionStream.streamId());
120138

121139
Flux<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> idWithMessages = response
122140
.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
@@ -126,14 +144,14 @@ private void reconnect(McpStream stream, ContextView ctx) {
126144
return sessionStream.consumeSseStream(idWithMessages);
127145
}
128146
else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) {
129-
logger.info("The server does not support SSE streams, using request-response mode.");
147+
logger.debug("The server does not support SSE streams, using request-response mode.");
130148
return Flux.empty();
131149
}
132150
else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
133-
logger.info("Session {} was not found on the MCP server", transportSession.sessionId());
151+
logger.warn("Session {} was not found on the MCP server", transportSession.sessionId());
134152

135153
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException(
136-
"Session " + transportSession.sessionId() + " not found");
154+
transportSession.sessionId());
137155
// inform the stream/connection subscriber
138156
return Flux.error(notFoundException);
139157
}
@@ -143,8 +161,9 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
143161
}).flux();
144162
}
145163
})
146-
.doOnError(e -> {
147-
this.exceptionHandler.get().accept(e);
164+
.onErrorResume(t -> {
165+
this.handleException(t);
166+
return Flux.empty();
148167
})
149168
.doFinally(s -> {
150169
Disposable ref = disposableRef.getAndSet(null);
@@ -161,7 +180,7 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
161180
@Override
162181
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
163182
return Mono.create(sink -> {
164-
System.out.println("Sending message " + message);
183+
logger.debug("Sending message {}", message);
165184
// Here we attempt to initialize the client.
166185
// In case the server supports SSE, we will establish a long-running session
167186
// here and
@@ -182,8 +201,9 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
182201
.exchangeToFlux(response -> {
183202
if (transportSession.markInitialized()) {
184203
if (!response.headers().header("mcp-session-id").isEmpty()) {
185-
transportSession
186-
.setSessionId(response.headers().asHttpHeaders().getFirst("mcp-session-id"));
204+
String sessionId = response.headers().asHttpHeaders().getFirst("mcp-session-id");
205+
logger.debug("Established session with id {}", sessionId);
206+
transportSession.setSessionId(sessionId);
187207
// Once we have a session, we try to open an async stream for
188208
// the server to send notifications and requests out-of-band.
189209
reconnect(null, sink.contextView());
@@ -193,12 +213,72 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
193213
// The spec mentions only ACCEPTED, but the existing SDKs can return
194214
// 200 OK for notifications
195215
// if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) {
196-
if (!response.statusCode().is2xxSuccessful()) {
216+
if (response.statusCode().is2xxSuccessful()) {
217+
// Existing SDKs consume notifications with no response body nor
218+
// content type
219+
if (response.headers().contentType().isEmpty()) {
220+
logger.trace("Message was successfuly sent via POST for session {}",
221+
transportSession.sessionId());
222+
// signal the caller that the message was successfully
223+
// delivered
224+
sink.success();
225+
// communicate to downstream there is no streamed data coming
226+
return Flux.empty();
227+
}
228+
229+
MediaType contentType = response.headers().contentType().get();
230+
231+
if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
232+
// communicate to caller that the message was delivered
233+
sink.success();
234+
235+
// starting a stream
236+
McpStream sessionStream = new McpStream(this.resumableStreams);
237+
238+
logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(),
239+
transportSession.sessionId());
240+
241+
Flux<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> idWithMessages = response
242+
.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
243+
})
244+
.map(this::parse);
245+
246+
return sessionStream.consumeSseStream(idWithMessages);
247+
}
248+
else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
249+
logger.trace("Received response to POST for session {}", transportSession.sessionId());
250+
251+
// communicate to caller the message was delivered
252+
sink.success();
253+
254+
// provide the response body as a stream of a single response
255+
// to consume
256+
return response.bodyToMono(
257+
String.class).<Iterable<McpSchema.JSONRPCMessage>>handle((responseMessage, s) -> {
258+
try {
259+
McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema
260+
.deserializeJsonRpcMessage(objectMapper, responseMessage);
261+
s.next(List.of(jsonRpcResponse));
262+
}
263+
catch (IOException e) {
264+
s.error(e);
265+
}
266+
})
267+
.flatMapIterable(Function.identity());
268+
}
269+
else {
270+
logger.warn("Unknown media type {} returned for POST in session {}", contentType,
271+
transportSession.sessionId());
272+
sink.error(new RuntimeException("Unknown media type returned: " + contentType));
273+
return Flux.empty();
274+
}
275+
}
276+
else {
197277
if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) {
198-
logger.info("Session {} was not found on the MCP server", transportSession.sessionId());
278+
logger.warn("Session {} was not found on the MCP server", transportSession.sessionId());
199279

200280
McpSessionNotFoundException notFoundException = new McpSessionNotFoundException(
201-
"Session " + transportSession.sessionId() + " not found");
281+
transportSession.sessionId());
202282
// inform the caller of sendMessage
203283
sink.error(notFoundException);
204284
// inform the stream/connection subscriber
@@ -208,58 +288,14 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
208288
sink.error(new RuntimeException("Sending request failed", e));
209289
}).flux();
210290
}
211-
212-
// Existing SDKs consume notifications with no response body nor
213-
// content type
214-
if (response.headers().contentType().isEmpty()) {
215-
sink.success();
216-
return Flux.empty();
217-
// return
218-
// response.<McpSchema.JSONRPCMessage>createError().doOnError(e ->
219-
// {
220-
//// sink.error(new RuntimeException("Response has no content
221-
// type"));
222-
// }).flux();
223-
}
224-
225-
MediaType contentType = response.headers().contentType().get();
226-
227-
if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
228-
sink.success();
229-
McpStream sessionStream = new McpStream(this.resumableStreams);
230-
231-
Flux<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> idWithMessages = response
232-
.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
233-
})
234-
.map(this::parse);
235-
236-
return sessionStream.consumeSseStream(idWithMessages);
237-
}
238-
else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
239-
sink.success();
240-
// return response.bodyToMono(new
241-
// ParameterizedTypeReference<Iterable<McpSchema.JSONRPCMessage>>()
242-
// {});
243-
return response.bodyToMono(
244-
String.class).<Iterable<McpSchema.JSONRPCMessage>>handle((responseMessage, s) -> {
245-
try {
246-
McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema
247-
.deserializeJsonRpcMessage(objectMapper, responseMessage);
248-
s.next(List.of(jsonRpcResponse));
249-
}
250-
catch (IOException e) {
251-
s.error(e);
252-
}
253-
})
254-
.flatMapIterable(Function.identity());
255-
}
256-
else {
257-
sink.error(new RuntimeException("Unknown media type"));
258-
return Flux.empty();
259-
}
260291
})
261292
.map(Mono::just)
262293
.flatMap(this.handler.get())
294+
.onErrorResume(t -> {
295+
this.handleException(t);
296+
sink.error(t);
297+
return Flux.empty();
298+
})
263299
.doFinally(s -> {
264300
Disposable ref = disposableRef.getAndSet(null);
265301
if (ref != null) {
@@ -281,7 +317,7 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
281317
private Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> parse(ServerSentEvent<String> event) {
282318
if (MESSAGE_EVENT_TYPE.equals(event.event())) {
283319
try {
284-
// TODO: support batching
320+
// TODO: support batching?
285321
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data());
286322
return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
287323
}
@@ -313,6 +349,10 @@ String lastId() {
313349
return this.lastId.get();
314350
}
315351

352+
long streamId() {
353+
return this.streamId;
354+
}
355+
316356
Flux<McpSchema.JSONRPCMessage> consumeSseStream(
317357
Publisher<Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>>> eventStream) {
318358
return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> {

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
import com.fasterxml.jackson.databind.ObjectMapper;
44
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
55
import io.modelcontextprotocol.spec.McpClientTransport;
6-
import org.junit.jupiter.api.AfterAll;
7-
import org.junit.jupiter.api.BeforeAll;
86
import org.springframework.web.reactive.function.client.WebClient;
9-
import org.testcontainers.containers.GenericContainer;
10-
import org.testcontainers.containers.wait.strategy.Wait;
117

128
public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests {
139

mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
</appender>
1010

1111
<!-- Main MCP package -->
12-
<logger name="org.springframework.ai.mcp" level="INFO"/>
12+
<logger name="org.springframework.ai.mcp" level="DEBUG"/>
1313

1414
<!-- Client packages -->
15-
<logger name="org.springframework.ai.mcp.client" level="INFO"/>
15+
<logger name="org.springframework.ai.mcp.client" level="DEBUG"/>
1616

1717
<!-- Spec package -->
18-
<logger name="org.springframework.ai.mcp.spec" level="INFO"/>
18+
<logger name="org.springframework.ai.mcp.spec" level="DEBUG"/>
1919

2020

2121
<!-- Root logger -->

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import org.junit.jupiter.api.AfterAll;
1111
import org.junit.jupiter.api.BeforeAll;
1212
import org.junit.jupiter.api.Test;
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
1315
import org.testcontainers.containers.GenericContainer;
1416
import org.testcontainers.containers.Network;
1517
import org.testcontainers.containers.ToxiproxyContainer;
@@ -27,6 +29,8 @@
2729

2830
public abstract class AbstractMcpAsyncClientResiliencyTests {
2931

32+
private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class);
33+
3034
static Network network = Network.newNetwork();
3135
static String host = "http://localhost:3001";
3236

@@ -65,6 +69,37 @@ public abstract class AbstractMcpAsyncClientResiliencyTests {
6569
host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy;
6670
}
6771

72+
void disconnect() {
73+
long start = System.nanoTime();
74+
try {
75+
// disconnect
76+
// proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM",
77+
// ToxicDirection.DOWNSTREAM, 0);
78+
// proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM",
79+
// ToxicDirection.UPSTREAM, 0);
80+
proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0);
81+
proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0);
82+
logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis());
83+
}
84+
catch (IOException e) {
85+
throw new RuntimeException("Failed to disconnect", e);
86+
}
87+
}
88+
89+
void reconnect() {
90+
long start = System.nanoTime();
91+
try {
92+
proxy.toxics().get("RESET_UPSTREAM").remove();
93+
proxy.toxics().get("RESET_DOWNSTREAM").remove();
94+
// proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove();
95+
// proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove();
96+
logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis());
97+
}
98+
catch (IOException e) {
99+
throw new RuntimeException("Failed to reconnect", e);
100+
}
101+
}
102+
68103
abstract McpClientTransport createMcpTransport();
69104

70105
protected Duration getRequestTimeout() {
@@ -112,29 +147,15 @@ void withClient(McpClientTransport transport, Function<McpClient.AsyncSpec, McpC
112147
@Test
113148
void testPing() {
114149
withClient(createMcpTransport(), mcpAsyncClient -> {
115-
try {
116-
StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
117-
118-
// disconnect
119-
// proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM",
120-
// ToxicDirection.DOWNSTREAM, 0);
121-
// proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM",
122-
// ToxicDirection.UPSTREAM, 0);
123-
proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0);
124-
proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0);
125-
126-
StepVerifier.create(mcpAsyncClient.ping()).expectError().verify();
127-
128-
proxy.toxics().get("RESET_UPSTREAM").remove();
129-
proxy.toxics().get("RESET_DOWNSTREAM").remove();
130-
// proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove();
131-
// proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove();
132-
133-
StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete();
134-
}
135-
catch (IOException e) {
136-
throw new RuntimeException(e);
137-
}
150+
StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
151+
152+
disconnect();
153+
154+
StepVerifier.create(mcpAsyncClient.ping()).expectError().verify();
155+
156+
reconnect();
157+
158+
StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete();
138159
});
139160
}
140161

0 commit comments

Comments
 (0)