diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 65416b256..65545695c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -91,6 +91,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String sseEndpoint; + private final String publishedMessageEndpoint; + private final RouterFunction routerFunction; private McpServerSession.Factory sessionFactory; @@ -105,6 +107,36 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param publishedMessageEndpoint The published endpoint URI that will be sent to clients + * (might differ from the actual messageEndpoint in case of proxies, load balancers, etc.) + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, + String publishedMessageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(publishedMessageEndpoint, "Published message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.publishedMessageEndpoint = publishedMessageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -122,6 +154,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag this.objectMapper = objectMapper; this.messageEndpoint = messageEndpoint; + this.publishedMessageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) @@ -248,7 +281,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(messageEndpoint + "?sessionId=" + sessionId); + .data(publishedMessageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage());