diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java index 1cd540f72..21f751d89 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java @@ -47,4 +47,13 @@ public interface McpTransportContext { */ McpTransportContext copy(); + /** + * Sends a notification from the server to the client. + * @param method notification method name + * @param params any parameters or {@code null} + */ + default void sendNotification(String method, Object params) { + throw new UnsupportedOperationException("Not supported in this implementation of MCP transport context"); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java new file mode 100644 index 000000000..b2b0a6cb8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/StatelessMcpTransportContext.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.function.BiConsumer; + +public class StatelessMcpTransportContext implements McpTransportContext { + + private final McpTransportContext delegate; + + private final BiConsumer notificationHandler; + + /** + * Create an empty instance. + */ + public StatelessMcpTransportContext(BiConsumer notificationHandler) { + this(new DefaultMcpTransportContext(), notificationHandler); + } + + private StatelessMcpTransportContext(McpTransportContext delegate, BiConsumer notificationHandler) { + this.delegate = delegate; + this.notificationHandler = notificationHandler; + } + + @Override + public Object get(String key) { + return this.delegate.get(key); + } + + @Override + public void put(String key, Object value) { + this.delegate.put(key, value); + } + + public McpTransportContext copy() { + return new StatelessMcpTransportContext(delegate.copy(), notificationHandler); + } + + @Override + public void sendNotification(String method, Object params) { + notificationHandler.accept(method, params); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 25b003564..041471965 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -4,19 +4,11 @@ package io.modelcontextprotocol.server.transport; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.DefaultMcpTransportContext; import io.modelcontextprotocol.server.McpStatelessServerHandler; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.StatelessMcpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStatelessServerTransport; @@ -26,8 +18,17 @@ import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; + /** * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. * @@ -123,11 +124,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + AtomicInteger nextId = new AtomicInteger(0); + AtomicBoolean upgradedToSse = new AtomicBoolean(false); + BiConsumer notificationHandler = buildNotificationHandler(response, upgradedToSse, nextId); + McpTransportContext transportContext = this.contextExtractor.extract(request, + new StatelessMcpTransportContext(notificationHandler)); String accept = request.getHeader(ACCEPT); if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("Both application/json and text/event-stream required in Accept header")); return; } @@ -149,18 +155,24 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) .block(); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_OK); - String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse); - PrintWriter writer = response.getWriter(); - writer.write(jsonResponseText); - writer.flush(); + if (upgradedToSse.get()) { + sendEvent(response.getWriter(), jsonResponseText, nextId.getAndIncrement()); + } + else { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + + PrintWriter writer = response.getWriter(); + writer.write(jsonResponseText); + writer.flush(); + } } catch (Exception e) { logger.error("Failed to handle request: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, jsonrpcRequest.id(), + upgradedToSse.get(), nextId.getAndIncrement(), new McpError("Failed to handle request: " + e.getMessage())); } } @@ -173,23 +185,25 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { } catch (Exception e) { logger.error("Failed to handle notification: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, null, + upgradedToSse.get(), nextId.getAndIncrement(), new McpError("Failed to handle notification: " + e.getMessage())); } } else { - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, - new McpError("The server accepts either requests or notifications")); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("The server accepts either requests or notifications")); } } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("Invalid message format")); } catch (Exception e) { logger.error("Unexpected error handling message: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Unexpected error: " + e.getMessage())); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, null, upgradedToSse.get(), + nextId.getAndIncrement(), new McpError("Unexpected error: " + e.getMessage())); } } @@ -197,17 +211,27 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { * Sends an error response to the client. * @param response The HTTP servlet response * @param httpCode The HTTP status code + * @param upgradedToSse true if the response is upgraded to SSE, false otherwise + * @param eventIdIfNeeded if upgradedToSse, the event ID to use, otherwise ignored * @param mcpError The MCP error to send * @throws IOException If an I/O error occurs */ - private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(httpCode); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); + private void responseError(HttpServletResponse response, int httpCode, Object requestId, boolean upgradedToSse, + int eventIdIfNeeded, McpError mcpError) throws IOException { + if (upgradedToSse) { + String jsonError = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + requestId, null, mcpError.getJsonRpcError())); + sendEvent(response.getWriter(), jsonError, eventIdIfNeeded); + } + else { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + PrintWriter writer = response.getWriter(); + String jsonError = objectMapper.writeValueAsString(mcpError); + writer.write(jsonError); + writer.flush(); + } } /** @@ -303,4 +327,43 @@ public HttpServletStatelessServerTransport build() { } + private BiConsumer buildNotificationHandler(HttpServletResponse response, + AtomicBoolean upgradedToSse, AtomicInteger nextId) { + AtomicBoolean responseInitialized = new AtomicBoolean(false); + + return (notificationMethod, params) -> { + if (responseInitialized.compareAndSet(false, true)) { + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + } + + upgradedToSse.set(true); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + notificationMethod, params); + try { + sendEvent(response.getWriter(), objectMapper.writeValueAsString(notification), + nextId.getAndIncrement()); + } + catch (IOException e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), null)); + } + }; + } + + private void sendEvent(PrintWriter writer, String data, int id) throws IOException { + // tested with MCP inspector. Event must consist of these two fields and only + // these two fields + writer.write("id: " + id + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index 4c3f22d76..c43caa356 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -35,12 +35,19 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.web.client.RestClient; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.stream.Stream; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; @@ -61,10 +68,13 @@ class HttpServletStatelessIntegrationTests { private Tomcat tomcat; + private ObjectMapper objectMapper; + @BeforeEach public void before() { + objectMapper = new ObjectMapper(); this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) + .objectMapper(objectMapper) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); @@ -219,6 +229,143 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { mcpServer.close(); } + @Test + void testNotifications() throws Exception { + + Tool tool = Tool.builder().name("test").build(); + Tool exceptionTool = Tool.builder().name("exception").build(); + + final int PROGRESS_QTY = 1000; + final String progressMessage = "We're working on it..."; + + var progressToken = UUID.randomUUID().toString(); + var callResponse = new CallToolResult(List.of(), null, null, Map.of("progressToken", progressToken)); + + McpStatelessServerFeatures.SyncToolSpecification toolSpecification = new McpStatelessServerFeatures.SyncToolSpecification( + tool, (transportContext, request) -> { + // Simulate sending progress notifications - send enough to ensure + // that cunked transfer encoding is used + for (int i = 0; i < PROGRESS_QTY; i++) { + transportContext.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, i, (double) PROGRESS_QTY, + progressMessage)); + } + return callResponse; + }); + + McpStatelessServerFeatures.SyncToolSpecification exceptionToolSpecification = new McpStatelessServerFeatures.SyncToolSpecification( + exceptionTool, (transportContext, request) -> { + // send 1 progress so that the response gets upgraded + transportContext.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, 1, 5.0, progressMessage)); + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS, + "bad tool", Map.of())); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(toolSpecification, exceptionToolSpecification) + .build(); + + HttpClient client = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(); + HttpRequest request = HttpRequest.newBuilder() + .method("POST", + HttpRequest.BodyPublishers.ofString( + objectMapper.writeValueAsString(new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "tools/call", "1", new McpSchema.CallToolRequest("test", Map.of()))))) + .header("Content-Type", APPLICATION_JSON) + .header("Accept", APPLICATION_JSON + "," + TEXT_EVENT_STREAM) + .uri(URI.create("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT)) + .build(); + + HttpResponse> response = client.send(request, HttpResponse.BodyHandlers.ofLines()); + assertThat(response.headers().firstValue("Transfer-Encoding")).contains("chunked"); + + List responseBody = response.body().toList(); + + assertThat(responseBody).hasSize((PROGRESS_QTY + 1) * 3); // 3 lines per progress + // notification + 4 + // for + // the call result + + Iterator iterator = responseBody.iterator(); + for (int i = 0; i < PROGRESS_QTY; ++i) { + String idLine = iterator.next(); + String dataLine = iterator.next(); + String blankLine = iterator.next(); + + McpSchema.ProgressNotification expectedNotification = new McpSchema.ProgressNotification(progressToken, i, + (double) PROGRESS_QTY, progressMessage); + McpSchema.JSONRPCNotification expectedJsonRpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_PROGRESS, expectedNotification); + + assertThat(idLine).isEqualTo("id: " + i); + assertThat(dataLine).isEqualTo("data: " + objectMapper.writeValueAsString(expectedJsonRpcNotification)); + assertThat(blankLine).isBlank(); + } + + String idLine = iterator.next(); + String dataLine = iterator.next(); + String blankLine = iterator.next(); + + assertThat(idLine).isEqualTo("id: " + PROGRESS_QTY); + assertThat(dataLine).isEqualTo("data: " + objectMapper + .writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, "1", callResponse, null))); + assertThat(blankLine).isBlank(); + + assertThat(iterator.hasNext()).isFalse(); + + // next, test the exception tool + + request = HttpRequest.newBuilder() + .method("POST", + HttpRequest.BodyPublishers.ofString( + objectMapper.writeValueAsString(new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + "tools/call", "1", new McpSchema.CallToolRequest("exception", Map.of()))))) + .header("Content-Type", APPLICATION_JSON) + .header("Accept", APPLICATION_JSON + "," + TEXT_EVENT_STREAM) + .uri(URI.create("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT)) + .build(); + + response = client.send(request, HttpResponse.BodyHandlers.ofLines()); + assertThat(response.headers().firstValue("Transfer-Encoding")).contains("chunked"); + + responseBody = response.body().toList(); + + assertThat(responseBody).hasSize(6); // 1 progress notification + the error + // response + + iterator = responseBody.iterator(); + + idLine = iterator.next(); + dataLine = iterator.next(); + blankLine = iterator.next(); + + McpSchema.ProgressNotification expectedNotification = new McpSchema.ProgressNotification(progressToken, 1, 5.0, + progressMessage); + McpSchema.JSONRPCNotification expectedJsonRpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_PROGRESS, expectedNotification); + + assertThat(idLine).isEqualTo("id: 0"); + assertThat(dataLine).isEqualTo("data: " + objectMapper.writeValueAsString(expectedJsonRpcNotification)); + assertThat(blankLine).isBlank(); + + idLine = iterator.next(); + dataLine = iterator.next(); + blankLine = iterator.next(); + + assertThat(iterator.hasNext()).isFalse(); + + assertThat(idLine).isEqualTo("id: 1"); + assertThat(dataLine).isEqualTo( + "data: " + objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, "1", + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INVALID_PARAMS, + "bad tool", Map.of())))); + assertThat(blankLine).isBlank(); + + mcpServer.close(); + } + // --------------------------------------- // Tool Structured Output Schema Tests // ---------------------------------------