Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
import io.modelcontextprotocol.spec.ProtocolVersions;
Expand Down Expand Up @@ -187,12 +189,19 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
return ServerResponse.notFound().build();
}

McpLoggableSession listenedStream = session.getListeningStream();
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(session.replay(lastId), ServerSentEvent.class);
}
if (listenedStream instanceof McpStreamableServerSessionStream) {
logger.debug(
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
sessionId);
listenedStream.close();
}

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
Expand Down Expand Up @@ -484,4 +493,4 @@ public WebFluxStreamableServerTransportProvider build() {

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
Expand Down Expand Up @@ -252,7 +254,6 @@ private ServerResponse handleGet(ServerRequest request) {
}

logger.debug("Handling GET request for session: {}", sessionId);

try {
return ServerResponse.sse(sseBuilder -> {
sseBuilder.onTimeout(() -> {
Expand All @@ -265,7 +266,6 @@ private ServerResponse handleGet(ServerRequest request) {
// Check if this is a replay request
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);

try {
session.replay(lastId)
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
Expand All @@ -288,6 +288,13 @@ private ServerResponse handleGet(ServerRequest request) {
}
}
else {
McpLoggableSession listenedStream = session.getListeningStream();
if (listenedStream instanceof McpStreamableServerSessionStream) {
logger.debug(
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
sessionId);
listenedStream.close();
}
// Establish new listening stream
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
.listeningStream(sessionTransport);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -273,7 +275,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

logger.debug("Handling GET request for session: {}", sessionId);

McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());

try {
Expand Down Expand Up @@ -315,6 +316,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}
}
else {
McpLoggableSession listenedStream = session.getListeningStream();
if (listenedStream instanceof McpStreamableServerSessionStream) {
logger.debug(
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
sessionId);
listenedStream.close();
}
// Establish new listening stream
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
.listeningStream(sessionTransport);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ public McpStreamableServerSessionStream listeningStream(McpStreamableServerTrans
return listeningStream;
}

public McpLoggableSession getListeningStream() {
return this.listeningStreamRef.get();
}

// TODO: keep track of history by keeping a map from eventId to stream and then
// iterate over the events using the lastEventId
public Flux<McpSchema.JSONRPCMessage> replay(Object lastEventId) {
Expand Down