Skip to content

Commit 3dac307

Browse files
committed
avoid streamable listening sse duplicate creation
1 parent cc67d8f commit 3dac307

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.HttpHeaders;
1212
import io.modelcontextprotocol.spec.McpError;
13+
import io.modelcontextprotocol.spec.McpLoggableSession;
1314
import io.modelcontextprotocol.spec.McpSchema;
1415
import io.modelcontextprotocol.spec.McpStreamableServerSession;
16+
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1517
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
1618
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1719
import io.modelcontextprotocol.spec.ProtocolVersions;
@@ -187,12 +189,18 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
187189
return ServerResponse.notFound().build();
188190
}
189191

190-
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
192+
McpLoggableSession listenedStream = session.getListeningStream();
193+
boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID);
194+
if (replayRequest) {
191195
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
192196
return ServerResponse.ok()
193197
.contentType(MediaType.TEXT_EVENT_STREAM)
194198
.body(session.replay(lastId), ServerSentEvent.class);
195199
}
200+
if (listenedStream instanceof McpStreamableServerSessionStream) {
201+
logger.debug("Listening stream for session: {} exists.", sessionId);
202+
return ServerResponse.ok().build();
203+
}
196204

197205
return ServerResponse.ok()
198206
.contentType(MediaType.TEXT_EVENT_STREAM)
@@ -484,4 +492,4 @@ public WebFluxStreamableServerTransportProvider build() {
484492

485493
}
486494

487-
}
495+
}

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import java.util.concurrent.ConcurrentHashMap;
1111
import java.util.concurrent.locks.ReentrantLock;
1212

13+
import io.modelcontextprotocol.spec.McpLoggableSession;
14+
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1315
import org.slf4j.Logger;
1416
import org.slf4j.LoggerFactory;
1517
import org.springframework.http.HttpStatus;
@@ -252,7 +254,12 @@ private ServerResponse handleGet(ServerRequest request) {
252254
}
253255

254256
logger.debug("Handling GET request for session: {}", sessionId);
255-
257+
McpLoggableSession listenedStream = session.getListeningStream();
258+
boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID);
259+
if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) {
260+
logger.debug("Listening stream for session: {} exists.", sessionId);
261+
return ServerResponse.ok().build();
262+
}
256263
try {
257264
return ServerResponse.sse(sseBuilder -> {
258265
sseBuilder.onTimeout(() -> {
@@ -263,9 +270,8 @@ private ServerResponse handleGet(ServerRequest request) {
263270
sessionId, sseBuilder);
264271

265272
// Check if this is a replay request
266-
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
273+
if (replayRequest) {
267274
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
268-
269275
try {
270276
session.replay(lastId)
271277
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import java.util.concurrent.ConcurrentHashMap;
1414
import java.util.concurrent.locks.ReentrantLock;
1515

16+
import io.modelcontextprotocol.spec.McpLoggableSession;
17+
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1618
import org.slf4j.Logger;
1719
import org.slf4j.LoggerFactory;
1820

@@ -273,6 +275,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
273275
}
274276

275277
logger.debug("Handling GET request for session: {}", sessionId);
278+
McpLoggableSession listenedStream = session.getListeningStream();
279+
boolean replayRequest = request.getHeader(HttpHeaders.LAST_EVENT_ID) != null;
280+
if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) {
281+
logger.debug("Listening stream for session: {} exists.", sessionId);
282+
response.setStatus(HttpServletResponse.SC_OK);
283+
return;
284+
}
276285

277286
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
278287

@@ -290,7 +299,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
290299
sessionId, asyncContext, response.getWriter());
291300

292301
// Check if this is a replay request
293-
if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) {
302+
if (replayRequest) {
294303
String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID);
295304

296305
try {

mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ public McpStreamableServerSessionStream listeningStream(McpStreamableServerTrans
142142
return listeningStream;
143143
}
144144

145+
public McpLoggableSession getListeningStream() {
146+
return this.listeningStreamRef.get();
147+
}
148+
145149
// TODO: keep track of history by keeping a map from eventId to stream and then
146150
// iterate over the events using the lastEventId
147151
public Flux<McpSchema.JSONRPCMessage> replay(Object lastEventId) {

0 commit comments

Comments
 (0)