diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index aaf7bab46..8af0fd0d0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -8,6 +8,7 @@ import java.time.Duration; import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -113,7 +114,7 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv /** * Map of active client sessions, keyed by session ID. */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private ConcurrentMap sessions; /** * Flag indicating if the transport is shutting down. @@ -194,10 +195,16 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, new ConcurrentHashMap<>()); + } + + private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, ConcurrentMap sessionsMap) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.notNull(sessionsMap, "Sessions map must not be null"); this.objectMapper = objectMapper; this.baseUrl = baseUrl; @@ -207,7 +214,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); - + this.sessions = sessionsMap; if (keepAliveInterval != null) { this.keepAliveScheduler = KeepAliveScheduler @@ -458,6 +465,8 @@ public static class Builder { private Duration keepAliveInterval; + private ConcurrentMap sessionsMap = new ConcurrentHashMap<>(); + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -519,6 +528,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Set the concurrentMap of active client sessions, keyed by mcp-session-id. + * @param sessionsMap the map of active client sessions, keyed by mcp-session-id + * @return @return this builder instance + */ + public Builder sessionsMap(ConcurrentMap sessionsMap) { + this.sessionsMap = sessionsMap; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -530,7 +549,7 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(messageEndpoint, "Message endpoint must be set"); return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + keepAliveInterval, sessionsMap); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index f3f6c2c33..72caee367 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -38,6 +38,7 @@ import java.time.Duration; import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; /** * Implementation of a WebFlux based {@link McpStreamableServerTransportProvider}. @@ -60,7 +61,7 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private McpStreamableServerSession.Factory sessionFactory; - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private ConcurrentMap sessions; private McpTransportContextExtractor contextExtractor; @@ -70,7 +71,7 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor, boolean disallowDelete, - Duration keepAliveInterval) { + Duration keepAliveInterval, ConcurrentMap sessionsMap) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); @@ -84,7 +85,7 @@ private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Stri .POST(this.mcpEndpoint, this::handlePost) .DELETE(this.mcpEndpoint, this::handleDelete) .build(); - + this.sessions = sessionsMap; if (keepAliveInterval != null) { this.keepAliveScheduler = KeepAliveScheduler .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) @@ -401,6 +402,8 @@ public static class Builder { private Duration keepAliveInterval; + private ConcurrentMap sessionsMap = new ConcurrentHashMap<>(); + private Builder() { // used by a static method } @@ -468,6 +471,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Set the concurrentMap of active client sessions, keyed by mcp-session-id. + * @param sessionsMap the map of active client sessions, keyed by mcp-session-id + * @return @return this builder instance + */ + public Builder sessionsMap(ConcurrentMap sessionsMap) { + this.sessionsMap = sessionsMap; + return this; + } + /** * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with * the configured settings. @@ -479,7 +492,7 @@ public WebFluxStreamableServerTransportProvider build() { Assert.notNull(mcpEndpoint, "Message endpoint must be set"); return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor, - disallowDelete, keepAliveInterval); + disallowDelete, keepAliveInterval, sessionsMap); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index ceeea31b1..e0a950d47 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -9,9 +9,9 @@ import java.io.PrintWriter; import java.time.Duration; import java.util.List; -import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import com.fasterxml.jackson.core.type.TypeReference; @@ -100,7 +100,7 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement private final String sseEndpoint; /** Map of active client sessions, keyed by session ID */ - private final Map sessions = new ConcurrentHashMap<>(); + private final ConcurrentMap sessions; /** Flag indicating if the transport is in the process of shutting down */ private final AtomicBoolean isClosing = new AtomicBoolean(false); @@ -164,10 +164,17 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, new ConcurrentHashMap<>()); + } + + private HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, ConcurrentMap sessions) { + this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.sessions = sessions; if (keepAliveInterval != null) { @@ -536,6 +543,8 @@ public static class Builder { private Duration keepAliveInterval; + private ConcurrentMap sessionsMap = new ConcurrentHashMap<>(); + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param objectMapper The object mapper to use @@ -595,6 +604,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Set the concurrentMap of active client sessions, keyed by mcp-session-id. + * @param sessionsMap the map of active client sessions, keyed by mcp-session-id + * @return @return this builder instance + */ + public Builder sessionsMap(ConcurrentMap sessionsMap) { + this.sessionsMap = sessionsMap; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -609,7 +628,7 @@ public HttpServletSseServerTransportProvider build() { throw new IllegalStateException("MessageEndpoint must be set"); } return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + keepAliveInterval, this.sessionsMap); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 8b95ec607..516b485d3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.locks.ReentrantLock; import org.slf4j.Logger; @@ -105,7 +106,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet /** * Map of active client sessions, keyed by mcp-session-id. */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ConcurrentMap sessions; private McpTransportContextExtractor contextExtractor; @@ -132,7 +133,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, - Duration keepAliveInterval) { + Duration keepAliveInterval, ConcurrentMap sessionsMap) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); @@ -141,6 +142,7 @@ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; + this.sessions = sessionsMap; if (keepAliveInterval != null) { @@ -773,6 +775,8 @@ public static class Builder { private Duration keepAliveInterval; + private ConcurrentMap sessionsMap = new ConcurrentHashMap<>(); + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -832,6 +836,16 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Set the concurrentMap of active client sessions, keyed by mcp-session-id. + * @param sessionsMap the map of active client sessions, keyed by mcp-session-id + * @return @return this builder instance + */ + public Builder sessionsMap(ConcurrentMap sessionsMap) { + this.sessionsMap = sessionsMap; + return this; + } + /** * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} * with the configured settings. @@ -843,7 +857,7 @@ public HttpServletStreamableServerTransportProvider build() { Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, - this.disallowDelete, this.contextExtractor, this.keepAliveInterval); + this.disallowDelete, this.contextExtractor, this.keepAliveInterval, this.sessionsMap); } }