Skip to content

Commit efd322f

Browse files
committed
Support server to client notifications from the stateless transport
The MCP spec allows stateless servers to send notifications to the client during a request. The response needs to be upgraded to SSE and the notifications are send in a stream until the final result is sent. This commit adds a `sendNotification` method to the transport context allowing each transport implementation to implement it or not. In this commit, HttpServletStatelessServerTransport implements the method and when the caller first sends a notification, the response is changed to `TEXT_EVENT_STREAM` and events are then streamed until the final result. This change will allow future features such as logging, list changes, etc. should we ever decide to support sessions in some manner. Even if we don't support sessions, sending progress notifications is a useful feature by itself.
1 parent 4532b61 commit efd322f

File tree

4 files changed

+235
-40
lines changed

4 files changed

+235
-40
lines changed

mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,13 @@ public interface McpTransportContext {
4747
*/
4848
McpTransportContext copy();
4949

50+
/**
51+
* Sends a notification from the server to the client.
52+
* @param method notification method name
53+
* @param params any parameters or {@code null}
54+
*/
55+
default void sendNotification(String method, Object params) {
56+
throw new UnsupportedOperationException("Not supported in this implementation of MCP transport context");
57+
}
58+
5059
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server;
6+
7+
import java.util.function.BiConsumer;
8+
9+
public class StatelessMcpTransportContext implements McpTransportContext {
10+
11+
private final McpTransportContext delegate;
12+
13+
private final BiConsumer<String, Object> notificationHandler;
14+
15+
/**
16+
* Create an empty instance.
17+
*/
18+
public StatelessMcpTransportContext(BiConsumer<String, Object> notificationHandler) {
19+
this(new DefaultMcpTransportContext(), notificationHandler);
20+
}
21+
22+
private StatelessMcpTransportContext(McpTransportContext delegate, BiConsumer<String, Object> notificationHandler) {
23+
this.delegate = delegate;
24+
this.notificationHandler = notificationHandler;
25+
}
26+
27+
@Override
28+
public Object get(String key) {
29+
return this.delegate.get(key);
30+
}
31+
32+
@Override
33+
public void put(String key, Object value) {
34+
this.delegate.put(key, value);
35+
}
36+
37+
public McpTransportContext copy() {
38+
return new StatelessMcpTransportContext(delegate.copy(), notificationHandler);
39+
}
40+
41+
@Override
42+
public void sendNotification(String method, Object params) {
43+
notificationHandler.accept(method, params);
44+
}
45+
46+
}

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,11 @@
44

55
package io.modelcontextprotocol.server.transport;
66

7-
import java.io.BufferedReader;
8-
import java.io.IOException;
9-
import java.io.PrintWriter;
10-
11-
import org.slf4j.Logger;
12-
import org.slf4j.LoggerFactory;
13-
147
import com.fasterxml.jackson.databind.ObjectMapper;
15-
16-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
178
import io.modelcontextprotocol.server.McpStatelessServerHandler;
189
import io.modelcontextprotocol.server.McpTransportContext;
1910
import io.modelcontextprotocol.server.McpTransportContextExtractor;
11+
import io.modelcontextprotocol.server.StatelessMcpTransportContext;
2012
import io.modelcontextprotocol.spec.McpError;
2113
import io.modelcontextprotocol.spec.McpSchema;
2214
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
@@ -26,8 +18,17 @@
2618
import jakarta.servlet.http.HttpServlet;
2719
import jakarta.servlet.http.HttpServletRequest;
2820
import jakarta.servlet.http.HttpServletResponse;
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
2923
import reactor.core.publisher.Mono;
3024

25+
import java.io.BufferedReader;
26+
import java.io.IOException;
27+
import java.io.PrintWriter;
28+
import java.util.concurrent.atomic.AtomicBoolean;
29+
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.function.BiConsumer;
31+
3132
/**
3233
* Implementation of an HttpServlet based {@link McpStatelessServerTransport}.
3334
*
@@ -123,7 +124,11 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
123124
return;
124125
}
125126

126-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
127+
AtomicInteger nextId = new AtomicInteger(0);
128+
AtomicBoolean upgradedToSse = new AtomicBoolean(false);
129+
BiConsumer<String, Object> notificationHandler = buildNotificationHandler(response, upgradedToSse, nextId);
130+
McpTransportContext transportContext = this.contextExtractor.extract(request,
131+
new StatelessMcpTransportContext(notificationHandler));
127132

128133
String accept = request.getHeader(ACCEPT);
129134
if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) {
@@ -149,14 +154,19 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
149154
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
150155
.block();
151156

152-
response.setContentType(APPLICATION_JSON);
153-
response.setCharacterEncoding(UTF_8);
154-
response.setStatus(HttpServletResponse.SC_OK);
155-
156157
String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse);
157-
PrintWriter writer = response.getWriter();
158-
writer.write(jsonResponseText);
159-
writer.flush();
158+
if (upgradedToSse.get()) {
159+
sendEvent(response.getWriter(), "result", jsonResponseText, nextId.getAndIncrement());
160+
}
161+
else {
162+
response.setContentType(APPLICATION_JSON);
163+
response.setCharacterEncoding(UTF_8);
164+
response.setStatus(HttpServletResponse.SC_OK);
165+
166+
PrintWriter writer = response.getWriter();
167+
writer.write(jsonResponseText);
168+
writer.flush();
169+
}
160170
}
161171
catch (Exception e) {
162172
logger.error("Failed to handle request: {}", e.getMessage());
@@ -303,4 +313,42 @@ public HttpServletStatelessServerTransport build() {
303313

304314
}
305315

316+
private BiConsumer<String, Object> buildNotificationHandler(HttpServletResponse response,
317+
AtomicBoolean upgradedToSse, AtomicInteger nextId) {
318+
AtomicBoolean responseInitialized = new AtomicBoolean(false);
319+
320+
return (notificationMethod, params) -> {
321+
upgradedToSse.set(true);
322+
323+
if (responseInitialized.compareAndSet(false, true)) {
324+
response.setContentType(TEXT_EVENT_STREAM);
325+
response.setCharacterEncoding(UTF_8);
326+
response.setStatus(HttpServletResponse.SC_OK);
327+
}
328+
329+
McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION,
330+
notificationMethod, params);
331+
try {
332+
sendEvent(response.getWriter(), "notification", objectMapper.writeValueAsString(notification),
333+
nextId.getAndIncrement());
334+
}
335+
catch (IOException e) {
336+
logger.error("Failed to handle notification: {}", e.getMessage());
337+
throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
338+
e.getMessage(), null));
339+
}
340+
};
341+
}
342+
343+
private void sendEvent(PrintWriter writer, String eventType, String data, int id) throws IOException {
344+
writer.write("event: " + eventType + "\n");
345+
writer.write("id: " + id + "\n");
346+
writer.write("data: " + data + "\n\n");
347+
writer.flush();
348+
349+
if (writer.checkError()) {
350+
throw new IOException("Client disconnected");
351+
}
352+
}
353+
306354
}

mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java

Lines changed: 115 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,7 @@
44

55
package io.modelcontextprotocol.server;
66

7-
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
8-
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
9-
import static org.assertj.core.api.Assertions.assertThat;
10-
import static org.awaitility.Awaitility.await;
11-
12-
import java.time.Duration;
13-
import java.util.List;
14-
import java.util.Map;
15-
import java.util.concurrent.ConcurrentHashMap;
16-
import java.util.concurrent.atomic.AtomicReference;
17-
import java.util.function.BiFunction;
18-
19-
import org.apache.catalina.LifecycleException;
20-
import org.apache.catalina.LifecycleState;
21-
import org.apache.catalina.startup.Tomcat;
22-
import org.junit.jupiter.api.AfterEach;
23-
import org.junit.jupiter.api.BeforeEach;
24-
import org.junit.jupiter.params.ParameterizedTest;
25-
import org.junit.jupiter.params.provider.ValueSource;
26-
import org.springframework.web.client.RestClient;
27-
287
import com.fasterxml.jackson.databind.ObjectMapper;
29-
308
import io.modelcontextprotocol.client.McpClient;
319
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
3210
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
@@ -42,6 +20,36 @@
4220
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
4321
import io.modelcontextprotocol.spec.McpSchema.Tool;
4422
import net.javacrumbs.jsonunit.core.Option;
23+
import org.apache.catalina.LifecycleException;
24+
import org.apache.catalina.LifecycleState;
25+
import org.apache.catalina.startup.Tomcat;
26+
import org.junit.jupiter.api.AfterEach;
27+
import org.junit.jupiter.api.BeforeEach;
28+
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.params.ParameterizedTest;
30+
import org.junit.jupiter.params.provider.ValueSource;
31+
import org.springframework.web.client.RestClient;
32+
33+
import java.net.URI;
34+
import java.net.http.HttpClient;
35+
import java.net.http.HttpRequest;
36+
import java.net.http.HttpResponse;
37+
import java.time.Duration;
38+
import java.util.Iterator;
39+
import java.util.List;
40+
import java.util.Map;
41+
import java.util.UUID;
42+
import java.util.concurrent.ConcurrentHashMap;
43+
import java.util.concurrent.atomic.AtomicReference;
44+
import java.util.function.BiFunction;
45+
import java.util.stream.Stream;
46+
47+
import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON;
48+
import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM;
49+
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
50+
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
51+
import static org.assertj.core.api.Assertions.assertThat;
52+
import static org.awaitility.Awaitility.await;
4553

4654
class HttpServletStatelessIntegrationTests {
4755

@@ -55,10 +63,13 @@ class HttpServletStatelessIntegrationTests {
5563

5664
private Tomcat tomcat;
5765

66+
private ObjectMapper objectMapper;
67+
5868
@BeforeEach
5969
public void before() {
70+
objectMapper = new ObjectMapper();
6071
this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder()
61-
.objectMapper(new ObjectMapper())
72+
.objectMapper(objectMapper)
6273
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
6374
.build();
6475

@@ -213,6 +224,87 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
213224
mcpServer.close();
214225
}
215226

227+
@Test
228+
void testNotifications() throws Exception {
229+
230+
Tool tool = Tool.builder().name("test").build();
231+
232+
final int PROGRESS_QTY = 1000;
233+
final String progressMessage = "We're working on it...";
234+
235+
var progressToken = UUID.randomUUID().toString();
236+
var callResponse = new CallToolResult(List.of(), null, null, Map.of("progressToken", progressToken));
237+
McpStatelessServerFeatures.SyncToolSpecification toolSpecification = new McpStatelessServerFeatures.SyncToolSpecification(
238+
tool, (transportContext, request) -> {
239+
// Simulate sending progress notifications - send enough to ensure
240+
// that cunked transfer encoding is used
241+
for (int i = 0; i < PROGRESS_QTY; i++) {
242+
transportContext.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS,
243+
new McpSchema.ProgressNotification(progressToken, i, 5.0, progressMessage));
244+
}
245+
return callResponse;
246+
});
247+
248+
var mcpServer = McpServer.sync(mcpStatelessServerTransport)
249+
.capabilities(ServerCapabilities.builder().tools(true).build())
250+
.tools(toolSpecification)
251+
.build();
252+
253+
HttpClient client = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build();
254+
HttpRequest request = HttpRequest.newBuilder()
255+
.method("POST",
256+
HttpRequest.BodyPublishers.ofString(
257+
objectMapper.writeValueAsString(new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION,
258+
"tools/call", "1", new McpSchema.CallToolRequest("test", Map.of())))))
259+
.header("Content-Type", APPLICATION_JSON)
260+
.header("Accept", APPLICATION_JSON + "," + TEXT_EVENT_STREAM)
261+
.uri(URI.create("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT))
262+
.build();
263+
264+
HttpResponse<Stream<String>> response = client.send(request, HttpResponse.BodyHandlers.ofLines());
265+
assertThat(response.headers().firstValue("Transfer-Encoding")).contains("chunked");
266+
267+
List<String> responseBody = response.body().toList();
268+
269+
assertThat(responseBody).hasSize((PROGRESS_QTY + 1) * 4); // 4 lines per progress
270+
// notification + 4
271+
// for
272+
// the call result
273+
274+
Iterator<String> iterator = responseBody.iterator();
275+
for (int i = 0; i < PROGRESS_QTY; ++i) {
276+
String eventLine = iterator.next();
277+
String idLine = iterator.next();
278+
String dataLine = iterator.next();
279+
String blankLine = iterator.next();
280+
281+
McpSchema.ProgressNotification expectedNotification = new McpSchema.ProgressNotification(progressToken, i,
282+
5.0, progressMessage);
283+
McpSchema.JSONRPCNotification expectedJsonRpcNotification = new McpSchema.JSONRPCNotification(
284+
McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_PROGRESS, expectedNotification);
285+
286+
assertThat(eventLine).isEqualTo("event: notification");
287+
assertThat(idLine).isEqualTo("id: " + i);
288+
assertThat(dataLine).isEqualTo("data: " + objectMapper.writeValueAsString(expectedJsonRpcNotification));
289+
assertThat(blankLine).isBlank();
290+
}
291+
292+
String eventLine = iterator.next();
293+
String idLine = iterator.next();
294+
String dataLine = iterator.next();
295+
String blankLine = iterator.next();
296+
297+
assertThat(eventLine).isEqualTo("event: result");
298+
assertThat(idLine).isEqualTo("id: " + PROGRESS_QTY);
299+
assertThat(dataLine).isEqualTo("data: " + objectMapper
300+
.writeValueAsString(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, "1", callResponse, null)));
301+
assertThat(blankLine).isBlank();
302+
303+
assertThat(iterator.hasNext()).isFalse();
304+
305+
mcpServer.close();
306+
}
307+
216308
// ---------------------------------------
217309
// Tool Structured Output Schema Tests
218310
// ---------------------------------------

0 commit comments

Comments
 (0)