diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index 44d89eaeb..d97400141 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -200,9 +200,16 @@ private Mono handleGet(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( sink); - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - sink.onDispose(listeningStream::close); + session.listeningStream(sessionTransport) + .doOnNext(serverSessionStream -> sink + .onDispose(() -> serverSessionStream.closeGracefully().subscribe(v -> { + }, error -> logger.warn("Failed to close listening stream gracefully", error)))) + .doOnError(error -> { + logger.error("Failed to create listening stream", error); + sink.error(error); + }) + .subscribe(serverSessionStream -> logger.debug("Listening stream created successfully"), + sink::error); // TODO Clarify why the outer context is not present in the // Flux.create sink? }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); @@ -491,4 +498,4 @@ public WebFluxStreamableServerTransportProvider build() { } -} \ No newline at end of file +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index 3cc104dd4..dd74e3088 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -288,13 +289,15 @@ private ServerResponse handleGet(ServerRequest request) { } else { // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + Mono listeningStream = session .listeningStream(sessionTransport); - - sseBuilder.onComplete(() -> { + listeningStream.subscribe(serverSessionStream -> sseBuilder.onComplete(() -> { logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - }); + serverSessionStream.close(); + }), error -> { + sseBuilder.error(error); + logger.error("Failed to create listening stream", error); + }, () -> logger.debug("Listening stream created successfully")); } }, Duration.ZERO); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 3cb8d7b15..25effa698 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -13,6 +13,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -315,33 +316,36 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } else { // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - - asyncContext.addListener(new jakarta.servlet.AsyncListener() { - @Override - public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection timed out for session: {}", sessionId); - listeningStream.close(); - } + session.listeningStream(sessionTransport) + .doOnNext(serverSessionStream -> asyncContext.addListener(new jakarta.servlet.AsyncListener() { + @Override + public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection completed for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection timed out for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onError(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection error for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { + // No action needed + } + })) + .doOnError(error -> { + logger.error("Failed to create listening stream", error); + }) + .subscribe(serverSessionStream -> logger.debug("Listening stream created successfully")); - @Override - public void onError(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection error for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { - // No action needed - } - }); } } catch (Exception e) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index af29ce0ad..30882d91c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -136,10 +136,16 @@ public Mono delete() { * @param transport The dedicated SSE transport stream * @return a stream representation */ - public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + public Mono listeningStream(McpStreamableServerTransport transport) { McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); - this.listeningStreamRef.set(listeningStream); - return listeningStream; + McpLoggableSession oldStream = this.listeningStreamRef.getAndSet(listeningStream); + if (oldStream != null) { + logger.debug( + "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", + this.id); + return oldStream.closeGracefully().thenReturn(listeningStream); + } + return Mono.just(listeningStream); } // TODO: keep track of history by keeping a map from eventId to stream and then