Skip to content

Commit 51fb931

Browse files
committed
McpSyncClient: introduce McpTransportContext
- McpSyncClient should be considered thread-agnostic, and therefore consumers cannot rely on thread locals to propagate "context", e.g. pass down the Servlet request reference in a server context. - This PR introduces a mechanism for populating an McpTransportContext before executing client operations, and reworks the HTTP request customizers to leverage that McpTransportContext. - This introduces a breaking change to the Sync/Async request customizers. Signed-off-by: Daniel Garnier-Moiroux <[email protected]>
1 parent 95ba8e7 commit 51fb931

File tree

30 files changed

+352
-155
lines changed

30 files changed

+352
-155
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
344344
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
345345
}
346346

347-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
347+
McpTransportContext transportContext = this.contextExtractor.extract(request);
348348

349349
return ServerResponse.ok()
350350
.contentType(MediaType.TEXT_EVENT_STREAM)
@@ -401,7 +401,7 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
401401
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
402402
}
403403

404-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
404+
McpTransportContext transportContext = this.contextExtractor.extract(request);
405405

406406
return request.bodyToMono(String.class).flatMap(body -> {
407407
try {

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
9797
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
9898
}
9999

100-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
100+
McpTransportContext transportContext = this.contextExtractor.extract(request);
101101

102102
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
103103
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
166166
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
167167
}
168168

169-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
169+
McpTransportContext transportContext = this.contextExtractor.extract(request);
170170

171171
return Mono.defer(() -> {
172172
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -221,7 +221,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
221221
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
222222
}
223223

224-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
224+
McpTransportContext transportContext = this.contextExtractor.extract(request);
225225

226226
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
227227
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -309,7 +309,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
309309
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
310310
}
311311

312-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
312+
McpTransportContext transportContext = this.contextExtractor.extract(request);
313313

314314
return Mono.defer(() -> {
315315
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,7 @@ private ServerResponse handleMessage(ServerRequest request) {
397397
}
398398

399399
try {
400-
final McpTransportContext transportContext = this.contextExtractor.extract(request,
401-
new DefaultMcpTransportContext());
400+
final McpTransportContext transportContext = this.contextExtractor.extract(request);
402401

403402
String body = request.body(String.class);
404403
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private ServerResponse handlePost(ServerRequest request) {
101101
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
102102
}
103103

104-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
104+
McpTransportContext transportContext = this.contextExtractor.extract(request);
105105

106106
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
107107
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ private ServerResponse handleGet(ServerRequest request) {
238238
return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM");
239239
}
240240

241-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
241+
McpTransportContext transportContext = this.contextExtractor.extract(request);
242242

243243
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
244244
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
@@ -322,7 +322,7 @@ private ServerResponse handlePost(ServerRequest request) {
322322
.body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON"));
323323
}
324324

325-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
325+
McpTransportContext transportContext = this.contextExtractor.extract(request);
326326

327327
try {
328328
String body = request.body(String.class);
@@ -431,7 +431,7 @@ private ServerResponse handleDelete(ServerRequest request) {
431431
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
432432
}
433433

434-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
434+
McpTransportContext transportContext = this.contextExtractor.extract(request);
435435

436436
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
437437
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");

mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@
1111
import java.util.Map;
1212
import java.util.function.Consumer;
1313
import java.util.function.Function;
14+
import java.util.function.Supplier;
1415

16+
import io.modelcontextprotocol.server.McpTransportContext;
1517
import io.modelcontextprotocol.spec.McpClientTransport;
1618
import io.modelcontextprotocol.spec.McpSchema;
17-
import io.modelcontextprotocol.spec.McpTransport;
1819
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
1920
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
2021
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
2122
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
2223
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
2324
import io.modelcontextprotocol.spec.McpSchema.Implementation;
2425
import io.modelcontextprotocol.spec.McpSchema.Root;
26+
import io.modelcontextprotocol.spec.McpTransport;
2527
import io.modelcontextprotocol.util.Assert;
2628
import reactor.core.publisher.Mono;
2729

@@ -183,6 +185,8 @@ class SyncSpec {
183185

184186
private Function<ElicitRequest, ElicitResult> elicitationHandler;
185187

188+
private Supplier<McpTransportContext> contextProvider = () -> McpTransportContext.EMPTY;
189+
186190
private SyncSpec(McpClientTransport transport) {
187191
Assert.notNull(transport, "Transport must not be null");
188192
this.transport = transport;
@@ -409,6 +413,22 @@ public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>>
409413
return this;
410414
}
411415

416+
/**
417+
* Add a provider of {@link McpTransportContext}, providing a context before
418+
* calling any client operation. This allows to extract thread-locals and hand
419+
* them over to the underlying transport.
420+
* <p>
421+
* There is no direct equivalent in {@link AsyncSpec}. To achieve the same result,
422+
* append {@code contextWrite(McpTransportContext.KEY, context)} to any
423+
* {@link McpAsyncClient} call.
424+
* @param contextProvider A supplier to create a context
425+
* @return This builder for method chaining
426+
*/
427+
public SyncSpec transportContextProvider(Supplier<McpTransportContext> contextProvider) {
428+
this.contextProvider = contextProvider;
429+
return this;
430+
}
431+
412432
/**
413433
* Create an instance of {@link McpSyncClient} with the provided configurations or
414434
* sensible defaults.
@@ -423,7 +443,8 @@ public McpSyncClient build() {
423443
McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
424444

425445
return new McpSyncClient(
426-
new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures));
446+
new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures),
447+
this.contextProvider);
427448
}
428449

429450
}

0 commit comments

Comments
 (0)