1919import com .fasterxml .jackson .core .type .TypeReference ;
2020import com .fasterxml .jackson .databind .ObjectMapper ;
2121
22- import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
22+ import io .modelcontextprotocol .server .DefaultMcpTransportContext ;
23+ import io .modelcontextprotocol .server .McpTransportContext ;
24+ import io .modelcontextprotocol .server .McpTransportContextExtractor ;
2325import io .modelcontextprotocol .spec .McpError ;
2426import io .modelcontextprotocol .spec .McpSchema ;
2527import io .modelcontextprotocol .spec .McpStreamableServerSession ;
2628import io .modelcontextprotocol .spec .McpStreamableServerTransport ;
2729import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
28- import io .modelcontextprotocol .spec .McpTransportContext ;
2930import io .modelcontextprotocol .util .Assert ;
3031import jakarta .servlet .AsyncContext ;
3132import jakarta .servlet .ServletException ;
@@ -117,8 +118,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
117118 */
118119 private final ConcurrentHashMap <String , McpStreamableServerSession > sessions = new ConcurrentHashMap <>();
119120
120- // TODO: add means to specify this
121- private Function <HttpServletRequest , McpTransportContext > contextExtractor = req -> new DefaultMcpTransportContext ();
121+ private McpTransportContextExtractor <HttpServletRequest > contextExtractor ;
122122
123123 /**
124124 * Flag indicating if the transport is shutting down.
@@ -132,16 +132,19 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
132132 * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
133133 * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
134134 * @param disallowDelete Whether to disallow DELETE requests on the endpoint.
135+ * @param contextExtractor The extractor for transport context from the request.
135136 * @throws IllegalArgumentException if any parameter is null
136137 */
137- public HttpServletStreamableServerTransportProvider (ObjectMapper objectMapper , String mcpEndpoint ,
138- boolean disallowDelete ) {
138+ private HttpServletStreamableServerTransportProvider (ObjectMapper objectMapper , String mcpEndpoint ,
139+ boolean disallowDelete , McpTransportContextExtractor < HttpServletRequest > contextExtractor ) {
139140 Assert .notNull (objectMapper , "ObjectMapper must not be null" );
140141 Assert .notNull (mcpEndpoint , "MCP endpoint must not be null" );
142+ Assert .notNull (contextExtractor , "Context extractor must not be null" );
141143
142144 this .objectMapper = objectMapper ;
143145 this .mcpEndpoint = mcpEndpoint ;
144146 this .disallowDelete = disallowDelete ;
147+ this .contextExtractor = contextExtractor ;
145148 }
146149
147150 @ Override
@@ -224,8 +227,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
224227 return ;
225228 }
226229
227- McpTransportContext transportContext = this .contextExtractor .apply (request );
228-
229230 List <String > badRequestErrors = new ArrayList <>();
230231
231232 String accept = request .getHeader (ACCEPT );
@@ -254,6 +255,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
254255
255256 logger .debug ("Handling GET request for session: {}" , sessionId );
256257
258+ McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext ());
259+
257260 try {
258261 response .setContentType (TEXT_EVENT_STREAM );
259262 response .setCharacterEncoding (UTF_8 );
@@ -277,7 +280,9 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
277280 .toIterable ()
278281 .forEach (message -> {
279282 try {
280- sessionTransport .sendMessage (message ).block ();
283+ sessionTransport .sendMessage (message )
284+ .contextWrite (ctx -> ctx .put (McpTransportContext .KEY , transportContext ))
285+ .block ();
281286 }
282287 catch (Exception e ) {
283288 logger .error ("Failed to replay message: {}" , e .getMessage ());
@@ -359,7 +364,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
359364 badRequestErrors .add ("application/json required in Accept header" );
360365 }
361366
362- McpTransportContext transportContext = this .contextExtractor .apply (request );
367+ McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext () );
363368
364369 try {
365370 BufferedReader reader = request .getReader ();
@@ -517,7 +522,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
517522 return ;
518523 }
519524
520- McpTransportContext transportContext = this .contextExtractor .apply (request );
525+ McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext () );
521526
522527 if (request .getHeader (MCP_SESSION_ID ) == null ) {
523528 this .responseError (response , HttpServletResponse .SC_BAD_REQUEST ,
@@ -745,6 +750,8 @@ public static class Builder {
745750
746751 private boolean disallowDelete = false ;
747752
753+ private McpTransportContextExtractor <HttpServletRequest > contextExtractor = (serverRequest , context ) -> context ;
754+
748755 /**
749756 * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
750757 * messages.
@@ -780,6 +787,18 @@ public Builder disallowDelete(boolean disallowDelete) {
780787 return this ;
781788 }
782789
790+ /**
791+ * Sets the context extractor for extracting transport context from the request.
792+ * @param contextExtractor The context extractor to use. Must not be null.
793+ * @return this builder instance
794+ * @throws IllegalArgumentException if contextExtractor is null
795+ */
796+ public Builder contextExtractor (McpTransportContextExtractor <HttpServletRequest > contextExtractor ) {
797+ Assert .notNull (contextExtractor , "Context extractor must not be null" );
798+ this .contextExtractor = contextExtractor ;
799+ return this ;
800+ }
801+
783802 /**
784803 * Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
785804 * with the configured settings.
@@ -791,7 +810,7 @@ public HttpServletStreamableServerTransportProvider build() {
791810 Assert .notNull (this .mcpEndpoint , "MCP endpoint must be set" );
792811
793812 return new HttpServletStreamableServerTransportProvider (this .objectMapper , this .mcpEndpoint ,
794- this .disallowDelete );
813+ this .disallowDelete , this . contextExtractor );
795814 }
796815
797816 }
0 commit comments