77import java .io .BufferedReader ;
88import java .io .IOException ;
99import java .io .PrintWriter ;
10+ import java .util .ArrayList ;
11+ import java .util .List ;
1012import java .util .concurrent .ConcurrentHashMap ;
1113import java .util .concurrent .locks .ReentrantLock ;
1214import java .util .function .Function ;
1315
16+ import org .slf4j .Logger ;
17+ import org .slf4j .LoggerFactory ;
18+
1419import com .fasterxml .jackson .core .type .TypeReference ;
1520import com .fasterxml .jackson .databind .ObjectMapper ;
21+
1622import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
1723import io .modelcontextprotocol .spec .McpError ;
1824import io .modelcontextprotocol .spec .McpSchema ;
2733import jakarta .servlet .http .HttpServlet ;
2834import jakarta .servlet .http .HttpServletRequest ;
2935import jakarta .servlet .http .HttpServletResponse ;
30- import org .slf4j .Logger ;
31- import org .slf4j .LoggerFactory ;
3236import reactor .core .publisher .Mono ;
3337
3438/**
4347 * for the core MCP module, providing streamable HTTP transport functionality without
4448 * Spring dependencies.
4549 *
50+ * @author Zachary German
4651 * @author Christian Tzolov
4752 * @author Dariusz Jędrzejczyk
4853 * @see McpStreamableServerTransportProvider
@@ -72,7 +77,12 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
7277 /**
7378 * Header name for the last message ID used in replay requests.
7479 */
75- private static final String MCP_LAST_ID = "Last-Event-ID" ;
80+ private static final String LAST_EVENT_ID = "Last-Event-ID" ;
81+
82+ /**
83+ * Header name for the response media types accepted by the requester.
84+ */
85+ private static final String ACCEPT = "Accept" ;
7686
7787 /**
7888 * Default base URL for the message endpoint.
@@ -216,13 +226,25 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
216226
217227 McpTransportContext transportContext = this .contextExtractor .apply (request );
218228
219- if (request .getHeader (MCP_SESSION_ID ) == null ) {
220- this .responseError (response , HttpServletResponse .SC_BAD_REQUEST ,
221- new McpError ("Session ID required in mcp-session-id header" ));
222- return ;
229+ List <String > badRequestErrors = new ArrayList <>();
230+
231+ String accept = request .getHeader (ACCEPT );
232+ if (accept == null || !accept .contains (TEXT_EVENT_STREAM )) {
233+ badRequestErrors .add ("text/event-stream required in Accept header" );
223234 }
224235
225236 String sessionId = request .getHeader (MCP_SESSION_ID );
237+
238+ if (sessionId == null || sessionId .isBlank ()) {
239+ badRequestErrors .add ("Session ID required in mcp-session-id header" );
240+ }
241+
242+ if (!badRequestErrors .isEmpty ()) {
243+ String combinedMessage = String .join ("; " , badRequestErrors );
244+ this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , new McpError (combinedMessage ));
245+ return ;
246+ }
247+
226248 McpStreamableServerSession session = this .sessions .get (sessionId );
227249
228250 if (session == null ) {
@@ -246,8 +268,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
246268 sessionId , asyncContext , response .getWriter ());
247269
248270 // Check if this is a replay request
249- if (request .getHeader (MCP_LAST_ID ) != null ) {
250- String lastId = request .getHeader (MCP_LAST_ID );
271+ if (request .getHeader (LAST_EVENT_ID ) != null ) {
272+ String lastId = request .getHeader (LAST_EVENT_ID );
251273
252274 try {
253275 session .replay (lastId )
@@ -327,6 +349,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
327349 return ;
328350 }
329351
352+ List <String > badRequestErrors = new ArrayList <>();
353+
354+ String accept = request .getHeader (ACCEPT );
355+ if (accept == null || !accept .contains (TEXT_EVENT_STREAM )) {
356+ badRequestErrors .add ("text/event-stream required in Accept header" );
357+ }
358+ if (accept == null || !accept .contains (APPLICATION_JSON )) {
359+ badRequestErrors .add ("application/json required in Accept header" );
360+ }
361+
330362 McpTransportContext transportContext = this .contextExtractor .apply (request );
331363
332364 try {
@@ -342,6 +374,12 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
342374 // Handle initialization request
343375 if (message instanceof McpSchema .JSONRPCRequest jsonrpcRequest
344376 && jsonrpcRequest .method ().equals (McpSchema .METHOD_INITIALIZE )) {
377+ if (!badRequestErrors .isEmpty ()) {
378+ String combinedMessage = String .join ("; " , badRequestErrors );
379+ this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , new McpError (combinedMessage ));
380+ return ;
381+ }
382+
345383 McpSchema .InitializeRequest initializeRequest = objectMapper .convertValue (jsonrpcRequest .params (),
346384 new TypeReference <McpSchema .InitializeRequest >() {
347385 });
@@ -373,13 +411,18 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
373411 }
374412 }
375413
376- // Handle other messages that require a session
377- if (request .getHeader (MCP_SESSION_ID ) == null ) {
378- this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , new McpError ("Session ID missing" ));
414+ String sessionId = request .getHeader (MCP_SESSION_ID );
415+
416+ if (sessionId == null || sessionId .isBlank ()) {
417+ badRequestErrors .add ("Session ID required in mcp-session-id header" );
418+ }
419+
420+ if (!badRequestErrors .isEmpty ()) {
421+ String combinedMessage = String .join ("; " , badRequestErrors );
422+ this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , new McpError (combinedMessage ));
379423 return ;
380424 }
381425
382- String sessionId = request .getHeader (MCP_SESSION_ID );
383426 McpStreamableServerSession session = this .sessions .get (sessionId );
384427
385428 if (session == null ) {
0 commit comments