Skip to content

Commit dcd6903

Browse files
Zachary Germantzolov
authored andcommitted
Added 'Accept' header validation and touched up 'Last-Event-ID' header
1 parent 6f072f8 commit dcd6903

File tree

3 files changed

+58
-15
lines changed

3 files changed

+58
-15
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,4 +440,4 @@ public WebFluxStreamableServerTransportProvider build() {
440440

441441
}
442442

443-
}
443+
}

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,4 +651,4 @@ public WebMvcStreamableServerTransportProvider build() {
651651

652652
}
653653

654-
}
654+
}

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77
import java.io.BufferedReader;
88
import java.io.IOException;
99
import java.io.PrintWriter;
10+
import java.util.ArrayList;
11+
import java.util.List;
1012
import java.util.concurrent.ConcurrentHashMap;
1113
import java.util.concurrent.locks.ReentrantLock;
1214
import java.util.function.Function;
1315

16+
import org.slf4j.Logger;
17+
import org.slf4j.LoggerFactory;
18+
1419
import com.fasterxml.jackson.core.type.TypeReference;
1520
import com.fasterxml.jackson.databind.ObjectMapper;
21+
1622
import io.modelcontextprotocol.spec.DefaultMcpTransportContext;
1723
import io.modelcontextprotocol.spec.McpError;
1824
import io.modelcontextprotocol.spec.McpSchema;
@@ -27,8 +33,6 @@
2733
import jakarta.servlet.http.HttpServlet;
2834
import jakarta.servlet.http.HttpServletRequest;
2935
import jakarta.servlet.http.HttpServletResponse;
30-
import org.slf4j.Logger;
31-
import org.slf4j.LoggerFactory;
3236
import reactor.core.publisher.Mono;
3337

3438
/**
@@ -43,6 +47,7 @@
4347
* for the core MCP module, providing streamable HTTP transport functionality without
4448
* Spring dependencies.
4549
*
50+
* @author Zachary German
4651
* @author Christian Tzolov
4752
* @author Dariusz Jędrzejczyk
4853
* @see McpStreamableServerTransportProvider
@@ -72,7 +77,12 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
7277
/**
7378
* Header name for the last message ID used in replay requests.
7479
*/
75-
private static final String MCP_LAST_ID = "Last-Event-ID";
80+
private static final String LAST_EVENT_ID = "Last-Event-ID";
81+
82+
/**
83+
* Header name for the response media types accepted by the requester.
84+
*/
85+
private static final String ACCEPT = "Accept";
7686

7787
/**
7888
* Default base URL for the message endpoint.
@@ -216,13 +226,25 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
216226

217227
McpTransportContext transportContext = this.contextExtractor.apply(request);
218228

219-
if (request.getHeader(MCP_SESSION_ID) == null) {
220-
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST,
221-
new McpError("Session ID required in mcp-session-id header"));
222-
return;
229+
List<String> badRequestErrors = new ArrayList<>();
230+
231+
String accept = request.getHeader(ACCEPT);
232+
if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) {
233+
badRequestErrors.add("text/event-stream required in Accept header");
223234
}
224235

225236
String sessionId = request.getHeader(MCP_SESSION_ID);
237+
238+
if (sessionId == null || sessionId.isBlank()) {
239+
badRequestErrors.add("Session ID required in mcp-session-id header");
240+
}
241+
242+
if (!badRequestErrors.isEmpty()) {
243+
String combinedMessage = String.join("; ", badRequestErrors);
244+
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage));
245+
return;
246+
}
247+
226248
McpStreamableServerSession session = this.sessions.get(sessionId);
227249

228250
if (session == null) {
@@ -246,8 +268,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
246268
sessionId, asyncContext, response.getWriter());
247269

248270
// Check if this is a replay request
249-
if (request.getHeader(MCP_LAST_ID) != null) {
250-
String lastId = request.getHeader(MCP_LAST_ID);
271+
if (request.getHeader(LAST_EVENT_ID) != null) {
272+
String lastId = request.getHeader(LAST_EVENT_ID);
251273

252274
try {
253275
session.replay(lastId)
@@ -327,6 +349,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
327349
return;
328350
}
329351

352+
List<String> badRequestErrors = new ArrayList<>();
353+
354+
String accept = request.getHeader(ACCEPT);
355+
if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) {
356+
badRequestErrors.add("text/event-stream required in Accept header");
357+
}
358+
if (accept == null || !accept.contains(APPLICATION_JSON)) {
359+
badRequestErrors.add("application/json required in Accept header");
360+
}
361+
330362
McpTransportContext transportContext = this.contextExtractor.apply(request);
331363

332364
try {
@@ -342,6 +374,12 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
342374
// Handle initialization request
343375
if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
344376
&& jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
377+
if (!badRequestErrors.isEmpty()) {
378+
String combinedMessage = String.join("; ", badRequestErrors);
379+
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage));
380+
return;
381+
}
382+
345383
McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
346384
new TypeReference<McpSchema.InitializeRequest>() {
347385
});
@@ -373,13 +411,18 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
373411
}
374412
}
375413

376-
// Handle other messages that require a session
377-
if (request.getHeader(MCP_SESSION_ID) == null) {
378-
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Session ID missing"));
414+
String sessionId = request.getHeader(MCP_SESSION_ID);
415+
416+
if (sessionId == null || sessionId.isBlank()) {
417+
badRequestErrors.add("Session ID required in mcp-session-id header");
418+
}
419+
420+
if (!badRequestErrors.isEmpty()) {
421+
String combinedMessage = String.join("; ", badRequestErrors);
422+
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage));
379423
return;
380424
}
381425

382-
String sessionId = request.getHeader(MCP_SESSION_ID);
383426
McpStreamableServerSession session = this.sessions.get(sessionId);
384427

385428
if (session == null) {

0 commit comments

Comments
 (0)