Skip to content

Issue #499 session storage pluggable #500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: 0.11.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -113,7 +114,7 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
/**
* Map of active client sessions, keyed by session ID.
*/
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
private ConcurrentMap<String, McpServerSession> sessions;

/**
* Flag indicating if the transport is shutting down.
Expand Down Expand Up @@ -194,10 +195,16 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
@Deprecated
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, new ConcurrentHashMap<>());
}

private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval, ConcurrentMap<String, McpServerSession> sessionsMap) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
Assert.notNull(sessionsMap, "Sessions map must not be null");

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
Expand All @@ -207,7 +214,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.build();

this.sessions = sessionsMap;
if (keepAliveInterval != null) {

this.keepAliveScheduler = KeepAliveScheduler
Expand Down Expand Up @@ -458,6 +465,8 @@ public static class Builder {

private Duration keepAliveInterval;

private ConcurrentMap<String, McpServerSession> sessionsMap = new ConcurrentHashMap<>();

/**
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -519,6 +528,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}

/**
* Set the concurrentMap of active client sessions, keyed by mcp-session-id.
* @param sessionsMap the map of active client sessions, keyed by mcp-session-id
* @return @return this builder instance
*/
public Builder sessionsMap(ConcurrentMap<String, McpServerSession> sessionsMap) {
this.sessionsMap = sessionsMap;
return this;
}

/**
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
* configured settings.
Expand All @@ -530,7 +549,7 @@ public WebFluxSseServerTransportProvider build() {
Assert.notNull(messageEndpoint, "Message endpoint must be set");

return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
keepAliveInterval);
keepAliveInterval, sessionsMap);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* Implementation of a WebFlux based {@link McpStreamableServerTransportProvider}.
Expand All @@ -60,7 +61,7 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe

private McpStreamableServerSession.Factory sessionFactory;

private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
private ConcurrentMap<String, McpStreamableServerSession> sessions;

private McpTransportContextExtractor<ServerRequest> contextExtractor;

Expand All @@ -70,7 +71,7 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe

private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
McpTransportContextExtractor<ServerRequest> contextExtractor, boolean disallowDelete,
Duration keepAliveInterval) {
Duration keepAliveInterval, ConcurrentMap<String, McpStreamableServerSession> sessionsMap) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(mcpEndpoint, "Message endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
Expand All @@ -84,7 +85,7 @@ private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Stri
.POST(this.mcpEndpoint, this::handlePost)
.DELETE(this.mcpEndpoint, this::handleDelete)
.build();

this.sessions = sessionsMap;
if (keepAliveInterval != null) {
this.keepAliveScheduler = KeepAliveScheduler
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values()))
Expand Down Expand Up @@ -401,6 +402,8 @@ public static class Builder {

private Duration keepAliveInterval;

private ConcurrentMap<String, McpStreamableServerSession> sessionsMap = new ConcurrentHashMap<>();

private Builder() {
// used by a static method
}
Expand Down Expand Up @@ -468,6 +471,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}

/**
* Set the concurrentMap of active client sessions, keyed by mcp-session-id.
* @param sessionsMap the map of active client sessions, keyed by mcp-session-id
* @return @return this builder instance
*/
public Builder sessionsMap(ConcurrentMap<String, McpStreamableServerSession> sessionsMap) {
this.sessionsMap = sessionsMap;
return this;
}

/**
* Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with
* the configured settings.
Expand All @@ -479,7 +492,7 @@ public WebFluxStreamableServerTransportProvider build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");

return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor,
disallowDelete, keepAliveInterval);
disallowDelete, keepAliveInterval, sessionsMap);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import java.io.PrintWriter;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;

import com.fasterxml.jackson.core.type.TypeReference;
Expand Down Expand Up @@ -100,7 +100,7 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
private final String sseEndpoint;

/** Map of active client sessions, keyed by session ID */
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
private final ConcurrentMap<String, McpServerSession> sessions;

/** Flag indicating if the transport is in the process of shutting down */
private final AtomicBoolean isClosing = new AtomicBoolean(false);
Expand Down Expand Up @@ -164,10 +164,17 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval) {

this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, new ConcurrentHashMap<>());
}

private HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval, ConcurrentMap<String, McpServerSession> sessions) {

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.sessions = sessions;

if (keepAliveInterval != null) {

Expand Down Expand Up @@ -536,6 +543,8 @@ public static class Builder {

private Duration keepAliveInterval;

private ConcurrentMap<String, McpServerSession> sessionsMap = new ConcurrentHashMap<>();

/**
* Sets the JSON object mapper to use for message serialization/deserialization.
* @param objectMapper The object mapper to use
Expand Down Expand Up @@ -595,6 +604,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}

/**
* Set the concurrentMap of active client sessions, keyed by mcp-session-id.
* @param sessionsMap the map of active client sessions, keyed by mcp-session-id
* @return @return this builder instance
*/
public Builder sessionsMap(ConcurrentMap<String, McpServerSession> sessionsMap) {
this.sessionsMap = sessionsMap;
return this;
}

/**
* Builds a new instance of HttpServletSseServerTransportProvider with the
* configured settings.
Expand All @@ -609,7 +628,7 @@ public HttpServletSseServerTransportProvider build() {
throw new IllegalStateException("MessageEndpoint must be set");
}
return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
keepAliveInterval);
keepAliveInterval, this.sessionsMap);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.locks.ReentrantLock;

import org.slf4j.Logger;
Expand Down Expand Up @@ -105,7 +106,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
/**
* Map of active client sessions, keyed by mcp-session-id.
*/
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
private final ConcurrentMap<String, McpStreamableServerSession> sessions;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

Expand All @@ -132,7 +133,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
*/
private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
boolean disallowDelete, McpTransportContextExtractor<HttpServletRequest> contextExtractor,
Duration keepAliveInterval) {
Duration keepAliveInterval, ConcurrentMap<String, McpStreamableServerSession> sessionsMap) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
Expand All @@ -141,6 +142,7 @@ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper,
this.mcpEndpoint = mcpEndpoint;
this.disallowDelete = disallowDelete;
this.contextExtractor = contextExtractor;
this.sessions = sessionsMap;

if (keepAliveInterval != null) {

Expand Down Expand Up @@ -773,6 +775,8 @@ public static class Builder {

private Duration keepAliveInterval;

private ConcurrentMap<String, McpStreamableServerSession> sessionsMap = new ConcurrentHashMap<>();

/**
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -832,6 +836,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}

/**
* Set the concurrentMap of active client sessions, keyed by mcp-session-id.
* @param sessionsMap the map of active client sessions, keyed by mcp-session-id
* @return @return this builder instance
*/
public Builder sessionsMap(ConcurrentMap<String, McpStreamableServerSession> sessionsMap) {
this.sessionsMap = sessionsMap;
return this;
}

/**
* Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
* with the configured settings.
Expand All @@ -843,7 +857,7 @@ public HttpServletStreamableServerTransportProvider build() {
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");

return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint,
this.disallowDelete, this.contextExtractor, this.keepAliveInterval);
this.disallowDelete, this.contextExtractor, this.keepAliveInterval, this.sessionsMap);
}

}
Expand Down