Skip to content

Commit 644f625

Browse files
authored
Merge branch 'main' into feature/add-uri-template-manager-tests
2 parents 611004c + 95ba8e7 commit 644f625

File tree

29 files changed

+785
-86
lines changed

29 files changed

+785
-86
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111

1212
import com.fasterxml.jackson.core.type.TypeReference;
1313
import com.fasterxml.jackson.databind.ObjectMapper;
14+
15+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
16+
import io.modelcontextprotocol.server.McpTransportContext;
17+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1418
import io.modelcontextprotocol.spec.McpError;
1519
import io.modelcontextprotocol.spec.McpSchema;
1620
import io.modelcontextprotocol.spec.McpServerSession;
@@ -115,6 +119,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
115119
*/
116120
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
117121

122+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
123+
118124
/**
119125
* Flag indicating if the transport is shutting down.
120126
*/
@@ -194,15 +200,38 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
194200
@Deprecated
195201
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
196202
String sseEndpoint, Duration keepAliveInterval) {
203+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
204+
(serverRequest, context) -> context);
205+
}
206+
207+
/**
208+
* Constructs a new WebFlux SSE server transport provider instance.
209+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
210+
* of MCP messages. Must not be null.
211+
* @param baseUrl webflux message base path
212+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
213+
* messages. This endpoint will be communicated to clients during SSE connection
214+
* setup. Must not be null.
215+
* @param sseEndpoint The SSE endpoint path. Must not be null.
216+
* @param keepAliveInterval The interval for sending keep-alive pings to clients.
217+
* @param contextExtractor The context extractor to use for extracting MCP transport
218+
* context from HTTP requests. Must not be null.
219+
* @throws IllegalArgumentException if either parameter is null
220+
*/
221+
private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
222+
String sseEndpoint, Duration keepAliveInterval,
223+
McpTransportContextExtractor<ServerRequest> contextExtractor) {
197224
Assert.notNull(objectMapper, "ObjectMapper must not be null");
198225
Assert.notNull(baseUrl, "Message base path must not be null");
199226
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
200227
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
228+
Assert.notNull(contextExtractor, "Context extractor must not be null");
201229

202230
this.objectMapper = objectMapper;
203231
this.baseUrl = baseUrl;
204232
this.messageEndpoint = messageEndpoint;
205233
this.sseEndpoint = sseEndpoint;
234+
this.contextExtractor = contextExtractor;
206235
this.routerFunction = RouterFunctions.route()
207236
.GET(this.sseEndpoint, this::handleSseConnection)
208237
.POST(this.messageEndpoint, this::handleMessage)
@@ -315,6 +344,8 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
315344
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
316345
}
317346

347+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
348+
318349
return ServerResponse.ok()
319350
.contentType(MediaType.TEXT_EVENT_STREAM)
320351
.body(Flux.<ServerSentEvent<?>>create(sink -> {
@@ -336,7 +367,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
336367
logger.debug("Session {} cancelled", sessionId);
337368
sessions.remove(sessionId);
338369
});
339-
}), ServerSentEvent.class);
370+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
340371
}
341372

342373
/**
@@ -370,6 +401,8 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
370401
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
371402
}
372403

404+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
405+
373406
return request.bodyToMono(String.class).flatMap(body -> {
374407
try {
375408
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
@@ -386,7 +419,7 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
386419
logger.error("Failed to deserialize message: {}", e.getMessage());
387420
return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
388421
}
389-
});
422+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
390423
}
391424

392425
private class WebFluxMcpSessionTransport implements McpServerTransport {
@@ -458,6 +491,8 @@ public static class Builder {
458491

459492
private Duration keepAliveInterval;
460493

494+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
495+
461496
/**
462497
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
463498
* messages.
@@ -519,6 +554,22 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
519554
return this;
520555
}
521556

557+
/**
558+
* Sets the context extractor that allows providing the MCP feature
559+
* implementations to inspect HTTP transport level metadata that was present at
560+
* HTTP request processing time. This allows to extract custom headers and other
561+
* useful data for use during execution later on in the process.
562+
* @param contextExtractor The contextExtractor to fill in a
563+
* {@link McpTransportContext}.
564+
* @return this builder instance
565+
* @throws IllegalArgumentException if contextExtractor is null
566+
*/
567+
public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
568+
Assert.notNull(contextExtractor, "contextExtractor must not be null");
569+
this.contextExtractor = contextExtractor;
570+
return this;
571+
}
572+
522573
/**
523574
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
524575
* configured settings.
@@ -530,7 +581,7 @@ public WebFluxSseServerTransportProvider build() {
530581
Assert.notNull(messageEndpoint, "Message endpoint must be set");
531582

532583
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
533-
keepAliveInterval);
584+
keepAliveInterval, contextExtractor);
534585
}
535586

536587
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
191191
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
192192
return ServerResponse.ok()
193193
.contentType(MediaType.TEXT_EVENT_STREAM)
194-
.body(session.replay(lastId), ServerSentEvent.class);
194+
.body(session.replay(lastId)
195+
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
196+
ServerSentEvent.class);
195197
}
196198

197199
return ServerResponse.ok()
@@ -202,7 +204,9 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
202204
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
203205
.listeningStream(sessionTransport);
204206
sink.onDispose(listeningStream::close);
205-
}), ServerSentEvent.class);
207+
// TODO Clarify why the outer context is not present in the
208+
// Flux.create sink?
209+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
206210

207211
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
208212
}
@@ -282,7 +286,10 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
282286
return true;
283287
}).contextWrite(sink.contextView()).subscribe();
284288
sink.onCancel(streamSubscription);
285-
}), ServerSentEvent.class);
289+
// TODO Clarify why the outer context is not present in the
290+
// Flux.create sink?
291+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
292+
ServerSentEvent.class);
286293
}
287294
else {
288295
return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type"));

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1414
import org.springframework.web.reactive.function.client.WebClient;
1515
import org.springframework.web.reactive.function.server.RouterFunctions;
16+
import org.springframework.web.reactive.function.server.ServerRequest;
1617

1718
import com.fasterxml.jackson.databind.ObjectMapper;
1819

@@ -22,6 +23,7 @@
2223
import io.modelcontextprotocol.server.McpServer;
2324
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2425
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
26+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2527
import io.modelcontextprotocol.server.TestUtil;
2628
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2729
import reactor.netty.DisposableServer;
@@ -40,6 +42,11 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
4042

4143
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
4244

45+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
46+
tc.put("important", "value");
47+
return tc;
48+
};
49+
4350
@Override
4451
protected void prepareClients(int port, String mcpEndpoint) {
4552

@@ -75,6 +82,7 @@ public void before() {
7582
.objectMapper(new ObjectMapper())
7683
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
7784
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
85+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
7886
.build();
7987

8088
HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction());

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1414
import org.springframework.web.reactive.function.client.WebClient;
1515
import org.springframework.web.reactive.function.server.RouterFunctions;
16+
import org.springframework.web.reactive.function.server.ServerRequest;
1617

1718
import com.fasterxml.jackson.databind.ObjectMapper;
1819

@@ -22,6 +23,7 @@
2223
import io.modelcontextprotocol.server.McpServer;
2324
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2425
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
26+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2527
import io.modelcontextprotocol.server.TestUtil;
2628
import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
2729
import reactor.netty.DisposableServer;
@@ -38,6 +40,11 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati
3840

3941
private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider;
4042

43+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44+
tc.put("important", "value");
45+
return tc;
46+
};
47+
4148
@Override
4249
protected void prepareClients(int port, String mcpEndpoint) {
4350

@@ -71,6 +78,7 @@ public void before() {
7178
this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder()
7279
.objectMapper(new ObjectMapper())
7380
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
81+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
7482
.build();
7583

7684
HttpHandler httpHandler = RouterFunctions

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
import com.fasterxml.jackson.core.type.TypeReference;
1515
import com.fasterxml.jackson.databind.ObjectMapper;
16+
17+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
18+
import io.modelcontextprotocol.server.McpTransportContext;
19+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1620
import io.modelcontextprotocol.spec.McpError;
1721
import io.modelcontextprotocol.spec.McpSchema;
1822
import io.modelcontextprotocol.spec.McpServerTransport;
@@ -106,6 +110,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
106110
*/
107111
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
108112

113+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
114+
109115
/**
110116
* Flag indicating if the transport is shutting down.
111117
*/
@@ -177,23 +183,47 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
177183
* messages via HTTP POST. This endpoint will be communicated to clients through the
178184
* SSE connection's initial endpoint event.
179185
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
180-
* * @param keepAliveInterval The interval for sending keep-alive messages to
186+
* @param keepAliveInterval The interval for sending keep-alive messages to clients.
181187
* @throws IllegalArgumentException if any parameter is null
182188
* @deprecated Use the builder {@link #builder()} instead for better configuration
183189
* options.
184190
*/
185191
@Deprecated
186192
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
187193
String sseEndpoint, Duration keepAliveInterval) {
194+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
195+
(serverRequest, context) -> context);
196+
}
197+
198+
/**
199+
* Constructs a new WebMvcSseServerTransportProvider instance.
200+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
201+
* of messages.
202+
* @param baseUrl The base URL for the message endpoint, used to construct the full
203+
* endpoint URL for clients.
204+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
205+
* messages via HTTP POST. This endpoint will be communicated to clients through the
206+
* SSE connection's initial endpoint event.
207+
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
208+
* @param keepAliveInterval The interval for sending keep-alive messages to clients.
209+
* @param contextExtractor The contextExtractor to fill in a
210+
* {@link McpTransportContext}.
211+
* @throws IllegalArgumentException if any parameter is null
212+
*/
213+
private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
214+
String sseEndpoint, Duration keepAliveInterval,
215+
McpTransportContextExtractor<ServerRequest> contextExtractor) {
188216
Assert.notNull(objectMapper, "ObjectMapper must not be null");
189217
Assert.notNull(baseUrl, "Message base URL must not be null");
190218
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
191219
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
220+
Assert.notNull(contextExtractor, "Context extractor must not be null");
192221

193222
this.objectMapper = objectMapper;
194223
this.baseUrl = baseUrl;
195224
this.messageEndpoint = messageEndpoint;
196225
this.sseEndpoint = sseEndpoint;
226+
this.contextExtractor = contextExtractor;
197227
this.routerFunction = RouterFunctions.route()
198228
.GET(this.sseEndpoint, this::handleSseConnection)
199229
.POST(this.messageEndpoint, this::handleMessage)
@@ -367,11 +397,17 @@ private ServerResponse handleMessage(ServerRequest request) {
367397
}
368398

369399
try {
400+
final McpTransportContext transportContext = this.contextExtractor.extract(request,
401+
new DefaultMcpTransportContext());
402+
370403
String body = request.body(String.class);
371404
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
372405

373406
// Process the message through the session's handle method
374-
session.handle(message).block(); // Block for WebMVC compatibility
407+
session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block
408+
// for
409+
// WebMVC
410+
// compatibility
375411

376412
return ServerResponse.ok().build();
377413
}
@@ -517,6 +553,8 @@ public static class Builder {
517553

518554
private Duration keepAliveInterval;
519555

556+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
557+
520558
/**
521559
* Sets the JSON object mapper to use for message serialization/deserialization.
522560
* @param objectMapper The object mapper to use
@@ -576,6 +614,22 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
576614
return this;
577615
}
578616

617+
/**
618+
* Sets the context extractor that allows providing the MCP feature
619+
* implementations to inspect HTTP transport level metadata that was present at
620+
* HTTP request processing time. This allows to extract custom headers and other
621+
* useful data for use during execution later on in the process.
622+
* @param contextExtractor The contextExtractor to fill in a
623+
* {@link McpTransportContext}.
624+
* @return this builder instance
625+
* @throws IllegalArgumentException if contextExtractor is null
626+
*/
627+
public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
628+
Assert.notNull(contextExtractor, "contextExtractor must not be null");
629+
this.contextExtractor = contextExtractor;
630+
return this;
631+
}
632+
579633
/**
580634
* Builds a new instance of WebMvcSseServerTransportProvider with the configured
581635
* settings.
@@ -587,7 +641,7 @@ public WebMvcSseServerTransportProvider build() {
587641
throw new IllegalStateException("MessageEndpoint must be set");
588642
}
589643
return new WebMvcSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
590-
keepAliveInterval);
644+
keepAliveInterval, contextExtractor);
591645
}
592646

593647
}

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.springframework.web.reactive.function.client.WebClient;
1818
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
1919
import org.springframework.web.servlet.function.RouterFunction;
20+
import org.springframework.web.servlet.function.ServerRequest;
2021
import org.springframework.web.servlet.function.ServerResponse;
2122

2223
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -39,6 +40,11 @@ class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
3940

4041
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
4142

43+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44+
tc.put("important", "value");
45+
return tc;
46+
};
47+
4248
@Override
4349
protected void prepareClients(int port, String mcpEndpoint) {
4450

@@ -60,6 +66,7 @@ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
6066
return WebMvcSseServerTransportProvider.builder()
6167
.objectMapper(new ObjectMapper())
6268
.messageEndpoint(MESSAGE_ENDPOINT)
69+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
6370
.build();
6471
}
6572

0 commit comments

Comments
 (0)