Skip to content

Commit ea50838

Browse files
committed
fix: improve streamable HTTP session reinitialization (#459)
Implements the MCP spec guidelines for streamable HTTP (re)initialization: - Server MAY terminate session and MUST respond with HTTP 404 for terminated session IDs - Client MUST start new session when receiving HTTP 404 for requests with session ID Changes: - Replace generic McpError with McpTransportException for transport-layer errors - Only throw McpTransportSessionNotFoundException when session ID is present in request (per spec: 404 with session ID means session terminated, without means general error) - Enhance error messages with more context (status codes, response events) - Use RuntimeException for non-transport specific SSE endpoint failures - Ensure consistent error handling across HTTP client transports - Improve error handling with standard Java exceptions. Replace generic McpError with appropriate standard exceptions: - Use IllegalArgumentException for invalid input parameters - Use IllegalStateException for state-related issues - Use RuntimeException wrapper for initialization failures - Use McpError.builder() with proper error codes for protocol errors Fixes #459 Signed-off-by: Christian Tzolov <[email protected]>
1 parent ed2c35e commit ea50838

25 files changed

+955
-90
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io.modelcontextprotocol.spec.McpClientTransport;
3232
import io.modelcontextprotocol.spec.McpError;
3333
import io.modelcontextprotocol.spec.McpSchema;
34+
import io.modelcontextprotocol.spec.McpTransportException;
3435
import io.modelcontextprotocol.spec.McpTransportSession;
3536
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
3637
import io.modelcontextprotocol.spec.McpTransportStream;
@@ -70,6 +71,8 @@
7071
*/
7172
public class WebClientStreamableHttpTransport implements McpClientTransport {
7273

74+
private static final String MISSING_SESSION_ID = "[missing_session_id]";
75+
7376
private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class);
7477

7578
private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26;
@@ -221,8 +224,13 @@ else if (isNotAllowed(response)) {
221224
return Flux.empty();
222225
}
223226
else if (isNotFound(response)) {
224-
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
225-
return mcpSessionNotFoundError(sessionIdRepresentation);
227+
if (transportSession.sessionId().isPresent()) {
228+
String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
229+
return mcpSessionNotFoundError(sessionIdRepresentation);
230+
}
231+
else {
232+
return this.extractError(response, MISSING_SESSION_ID);
233+
}
226234
}
227235
else {
228236
return response.<McpSchema.JSONRPCMessage>createError().doOnError(e -> {
@@ -318,10 +326,10 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
318326
}
319327
}
320328
else {
321-
if (isNotFound(response)) {
329+
if (isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) {
322330
return mcpSessionNotFoundError(sessionRepresentation);
323331
}
324-
return extractError(response, sessionRepresentation);
332+
return this.extractError(response, sessionRepresentation);
325333
}
326334
})
327335
.flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage)))
@@ -362,10 +370,10 @@ private Flux<McpSchema.JSONRPCMessage> extractError(ClientResponse response, Str
362370
McpSchema.JSONRPCResponse.class);
363371
jsonRpcError = jsonRpcResponse.error();
364372
toPropagate = jsonRpcError != null ? new McpError(jsonRpcError)
365-
: new McpError("Can't parse the jsonResponse " + jsonRpcResponse);
373+
: new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse);
366374
}
367375
catch (IOException ex) {
368-
toPropagate = new RuntimeException("Sending request failed", e);
376+
toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e);
369377
logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body);
370378
}
371379

@@ -374,7 +382,11 @@ private Flux<McpSchema.JSONRPCMessage> extractError(ClientResponse response, Str
374382
// invalidate the session
375383
// https://github.com/modelcontextprotocol/typescript-sdk/issues/389
376384
if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) {
377-
return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate));
385+
if (!sessionRepresentation.equals(MISSING_SESSION_ID)) {
386+
return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate));
387+
}
388+
return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session "
389+
+ sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate));
378390
}
379391
return Mono.error(toPropagate);
380392
}).flux();
@@ -403,7 +415,7 @@ private static boolean isEventStream(ClientResponse response) {
403415
}
404416

405417
private static String sessionIdOrPlaceholder(McpTransportSession<?> transportSession) {
406-
return transportSession.sessionId().orElse("[missing_session_id]");
418+
return transportSession.sessionId().orElse(MISSING_SESSION_ID);
407419
}
408420

409421
private Flux<McpSchema.JSONRPCMessage> directResponseFlux(McpSchema.JSONRPCMessage sentMessage,
@@ -421,8 +433,7 @@ private Flux<McpSchema.JSONRPCMessage> directResponseFlux(McpSchema.JSONRPCMessa
421433
}
422434
}
423435
catch (IOException e) {
424-
// TODO: this should be a McpTransportError
425-
s.error(e);
436+
s.error(new McpTransportException(e));
426437
}
427438
}).flatMapIterable(Function.identity());
428439
}
@@ -449,7 +460,7 @@ private Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> parse(Serve
449460
return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
450461
}
451462
catch (IOException ioException) {
452-
throw new McpError("Error parsing JSON-RPC message: " + event.data());
463+
throw new McpTransportException("Error parsing JSON-RPC message: " + event.data(), ioException);
453464
}
454465
}
455466
else {

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import io.modelcontextprotocol.spec.HttpHeaders;
1616
import io.modelcontextprotocol.spec.McpClientTransport;
17-
import io.modelcontextprotocol.spec.McpError;
1817
import io.modelcontextprotocol.spec.McpSchema;
1918
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
2019
import io.modelcontextprotocol.spec.ProtocolVersions;
@@ -197,8 +196,6 @@ public List<String> protocolVersions() {
197196
* @param handler a function that processes incoming JSON-RPC messages and returns
198197
* responses
199198
* @return a Mono that completes when the connection is fully established
200-
* @throws McpError if there's an error processing SSE events or if an unrecognized
201-
* event type is received
202199
*/
203200
@Override
204201
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
@@ -215,7 +212,7 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
215212
else {
216213
// TODO: clarify with the spec if multiple events can be
217214
// received
218-
s.error(new McpError("Failed to handle SSE endpoint event"));
215+
s.error(new RuntimeException("Failed to handle SSE endpoint event"));
219216
}
220217
}
221218
else if (MESSAGE_EVENT_TYPE.equals(event.event())) {

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.junit.jupiter.api.AfterEach;
1010
import org.junit.jupiter.api.BeforeEach;
11+
import org.junit.jupiter.api.Timeout;
1112
import org.springframework.http.server.reactive.HttpHandler;
1213
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1314
import org.springframework.web.reactive.function.client.WebClient;
@@ -26,6 +27,7 @@
2627
import reactor.netty.DisposableServer;
2728
import reactor.netty.http.server.HttpServer;
2829

30+
@Timeout(15)
2931
class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests {
3032

3133
private static final int PORT = TestUtil.findAvailablePort();

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.junit.jupiter.api.AfterEach;
1010
import org.junit.jupiter.api.BeforeEach;
11+
import org.junit.jupiter.api.Timeout;
1112
import org.springframework.http.server.reactive.HttpHandler;
1213
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1314
import org.springframework.web.reactive.function.client.WebClient;
@@ -26,6 +27,7 @@
2627
import reactor.netty.DisposableServer;
2728
import reactor.netty.http.server.HttpServer;
2829

30+
@Timeout(15)
2931
class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests {
3032

3133
private static final int PORT = TestUtil.findAvailablePort();

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.junit.jupiter.api.AfterEach;
1010
import org.junit.jupiter.api.BeforeEach;
11+
import org.junit.jupiter.api.Timeout;
1112
import org.springframework.http.server.reactive.HttpHandler;
1213
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1314
import org.springframework.web.reactive.function.client.WebClient;
@@ -26,6 +27,7 @@
2627
import reactor.netty.DisposableServer;
2728
import reactor.netty.http.server.HttpServer;
2829

30+
@Timeout(15)
2931
class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests {
3032

3133
private static final int PORT = TestUtil.findAvailablePort();

0 commit comments

Comments
 (0)