Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
import io.modelcontextprotocol.spec.ProtocolVersions;
Expand Down Expand Up @@ -187,12 +189,18 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
return ServerResponse.notFound().build();
}

if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
McpLoggableSession listenedStream = session.getListeningStream();
boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID);
if (replayRequest) {
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(session.replay(lastId), ServerSentEvent.class);
}
if (listenedStream instanceof McpStreamableServerSessionStream) {
logger.debug("Listening stream for session: {} exists.", sessionId);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if(logger.isDebugEnabled) ...

return ServerResponse.ok().build();
}

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
Expand Down Expand Up @@ -484,4 +492,4 @@ public WebFluxStreamableServerTransportProvider build() {

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
Expand Down Expand Up @@ -252,7 +254,12 @@ private ServerResponse handleGet(ServerRequest request) {
}

logger.debug("Handling GET request for session: {}", sessionId);

McpLoggableSession listenedStream = session.getListeningStream();
boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID);
if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) {
logger.debug("Listening stream for session: {} exists.", sessionId);
return ServerResponse.ok().build();
}
try {
return ServerResponse.sse(sseBuilder -> {
sseBuilder.onTimeout(() -> {
Expand All @@ -263,9 +270,8 @@ private ServerResponse handleGet(ServerRequest request) {
sessionId, sseBuilder);

// Check if this is a replay request
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
if (replayRequest) {
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);

try {
session.replay(lastId)
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -273,6 +275,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

logger.debug("Handling GET request for session: {}", sessionId);
McpLoggableSession listenedStream = session.getListeningStream();
boolean replayRequest = request.getHeader(HttpHeaders.LAST_EVENT_ID) != null;
if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) {
logger.debug("Listening stream for session: {} exists.", sessionId);
response.setStatus(HttpServletResponse.SC_OK);
return;
}

McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());

Expand All @@ -290,7 +299,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
sessionId, asyncContext, response.getWriter());

// Check if this is a replay request
if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) {
if (replayRequest) {
String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ public McpStreamableServerSessionStream listeningStream(McpStreamableServerTrans
return listeningStream;
}

public McpLoggableSession getListeningStream() {
return this.listeningStreamRef.get();
}

// TODO: keep track of history by keeping a map from eventId to stream and then
// iterate over the events using the lastEventId
public Flux<McpSchema.JSONRPCMessage> replay(Object lastEventId) {
Expand Down