Skip to content

Commit f13fef3

Browse files
committed
McpSyncClient: introduce McpTransportContext
- McpSyncClient should be considered thread-agnostic, and therefore consumers cannot rely on thread locals to propagate "context", e.g. pass down the Servlet request reference in a server context. - This PR introduces a mechanism for populating an McpTransportContext before executing client operations, and reworks the HTTP request customizers to leverage that McpTransportContext. - This introduces a breaking change to the Sync/Async request customizers. Signed-off-by: Daniel Garnier-Moiroux <[email protected]>
1 parent 95ba8e7 commit f13fef3

File tree

30 files changed

+316
-153
lines changed

30 files changed

+316
-153
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
344344
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
345345
}
346346

347-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
347+
McpTransportContext transportContext = this.contextExtractor.extract(request);
348348

349349
return ServerResponse.ok()
350350
.contentType(MediaType.TEXT_EVENT_STREAM)
@@ -401,7 +401,7 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
401401
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
402402
}
403403

404-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
404+
McpTransportContext transportContext = this.contextExtractor.extract(request);
405405

406406
return request.bodyToMono(String.class).flatMap(body -> {
407407
try {

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
9797
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
9898
}
9999

100-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
100+
McpTransportContext transportContext = this.contextExtractor.extract(request);
101101

102102
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
103103
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
166166
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
167167
}
168168

169-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
169+
McpTransportContext transportContext = this.contextExtractor.extract(request);
170170

171171
return Mono.defer(() -> {
172172
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -221,7 +221,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
221221
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
222222
}
223223

224-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
224+
McpTransportContext transportContext = this.contextExtractor.extract(request);
225225

226226
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
227227
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -309,7 +309,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
309309
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
310310
}
311311

312-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
312+
McpTransportContext transportContext = this.contextExtractor.extract(request);
313313

314314
return Mono.defer(() -> {
315315
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,7 @@ private ServerResponse handleMessage(ServerRequest request) {
397397
}
398398

399399
try {
400-
final McpTransportContext transportContext = this.contextExtractor.extract(request,
401-
new DefaultMcpTransportContext());
400+
final McpTransportContext transportContext = this.contextExtractor.extract(request);
402401

403402
String body = request.body(String.class);
404403
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private ServerResponse handlePost(ServerRequest request) {
101101
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
102102
}
103103

104-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
104+
McpTransportContext transportContext = this.contextExtractor.extract(request);
105105

106106
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
107107
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ private ServerResponse handleGet(ServerRequest request) {
238238
return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM");
239239
}
240240

241-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
241+
McpTransportContext transportContext = this.contextExtractor.extract(request);
242242

243243
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
244244
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
@@ -322,7 +322,7 @@ private ServerResponse handlePost(ServerRequest request) {
322322
.body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON"));
323323
}
324324

325-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
325+
McpTransportContext transportContext = this.contextExtractor.extract(request);
326326

327327
try {
328328
String body = request.body(String.class);
@@ -431,7 +431,7 @@ private ServerResponse handleDelete(ServerRequest request) {
431431
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
432432
}
433433

434-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
434+
McpTransportContext transportContext = this.contextExtractor.extract(request);
435435

436436
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
437437
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");

mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@
1111
import java.util.Map;
1212
import java.util.function.Consumer;
1313
import java.util.function.Function;
14+
import java.util.function.Supplier;
1415

16+
import io.modelcontextprotocol.server.McpTransportContext;
1517
import io.modelcontextprotocol.spec.McpClientTransport;
1618
import io.modelcontextprotocol.spec.McpSchema;
17-
import io.modelcontextprotocol.spec.McpTransport;
1819
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
1920
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
2021
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
2122
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
2223
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
2324
import io.modelcontextprotocol.spec.McpSchema.Implementation;
2425
import io.modelcontextprotocol.spec.McpSchema.Root;
26+
import io.modelcontextprotocol.spec.McpTransport;
2527
import io.modelcontextprotocol.util.Assert;
2628
import reactor.core.publisher.Mono;
2729

@@ -183,6 +185,8 @@ class SyncSpec {
183185

184186
private Function<ElicitRequest, ElicitResult> elicitationHandler;
185187

188+
private Supplier<McpTransportContext> contextProvider = () -> McpTransportContext.EMPTY;
189+
186190
private SyncSpec(McpClientTransport transport) {
187191
Assert.notNull(transport, "Transport must not be null");
188192
this.transport = transport;
@@ -409,6 +413,22 @@ public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>>
409413
return this;
410414
}
411415

416+
/**
417+
* Add a provider of {@link McpTransportContext}, providing a context before
418+
* calling any client operation. This allows to extract thread-locals and hand
419+
* them over to the underlying transport.
420+
* <p>
421+
* There is no direct equivalent in {@link AsyncSpec}. To achieve the same result,
422+
* append {@code contextWrite(McpTransportContext.KEY, context)} to any
423+
* {@link McpAsyncClient} call.
424+
* @param contextProvider A supplier to create a context
425+
* @return This builder for method chaining
426+
*/
427+
public SyncSpec transportContextProvider(Supplier<McpTransportContext> contextProvider) {
428+
this.contextProvider = contextProvider;
429+
return this;
430+
}
431+
412432
/**
413433
* Create an instance of {@link McpSyncClient} with the provided configurations or
414434
* sensible defaults.
@@ -423,7 +443,8 @@ public McpSyncClient build() {
423443
McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
424444

425445
return new McpSyncClient(
426-
new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures));
446+
new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures),
447+
this.contextProvider);
427448
}
428449

429450
}

mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
package io.modelcontextprotocol.client;
66

77
import java.time.Duration;
8+
import java.util.function.Supplier;
89

910
import org.slf4j.Logger;
1011
import org.slf4j.LoggerFactory;
1112

13+
import io.modelcontextprotocol.server.McpTransportContext;
1214
import io.modelcontextprotocol.spec.McpSchema;
1315
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
1416
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
1517
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
1618
import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult;
1719
import io.modelcontextprotocol.util.Assert;
20+
import reactor.core.publisher.Mono;
1821

1922
/**
2023
* A synchronous client implementation for the Model Context Protocol (MCP) that wraps an
@@ -63,14 +66,20 @@ public class McpSyncClient implements AutoCloseable {
6366

6467
private final McpAsyncClient delegate;
6568

69+
private final Supplier<McpTransportContext> contextProvider;
70+
6671
/**
6772
* Create a new McpSyncClient with the given delegate.
6873
* @param delegate the asynchronous kernel on top of which this synchronous client
6974
* provides a blocking API.
75+
* @param contextProvider the supplier of context before calling any non-blocking
76+
* operation on underlying delegate
7077
*/
71-
McpSyncClient(McpAsyncClient delegate) {
78+
McpSyncClient(McpAsyncClient delegate, Supplier<McpTransportContext> contextProvider) {
7279
Assert.notNull(delegate, "The delegate can not be null");
80+
Assert.notNull(contextProvider, "The contextProvider can not be null");
7381
this.delegate = delegate;
82+
this.contextProvider = contextProvider;
7483
}
7584

7685
/**
@@ -177,14 +186,14 @@ public boolean closeGracefully() {
177186
public McpSchema.InitializeResult initialize() {
178187
// TODO: block takes no argument here as we assume the async client is
179188
// configured with a requestTimeout at all times
180-
return this.delegate.initialize().block();
189+
return withProvidedContext(this.delegate.initialize()).block();
181190
}
182191

183192
/**
184193
* Send a roots/list_changed notification.
185194
*/
186195
public void rootsListChangedNotification() {
187-
this.delegate.rootsListChangedNotification().block();
196+
withProvidedContext(this.delegate.rootsListChangedNotification()).block();
188197
}
189198

190199
/**
@@ -206,7 +215,7 @@ public void removeRoot(String rootUri) {
206215
* @return
207216
*/
208217
public Object ping() {
209-
return this.delegate.ping().block();
218+
return withProvidedContext(this.delegate.ping()).block();
210219
}
211220

212221
// --------------------------
@@ -224,7 +233,8 @@ public Object ping() {
224233
* Boolean indicating if the execution failed (true) or succeeded (false/absent)
225234
*/
226235
public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) {
227-
return this.delegate.callTool(callToolRequest).block();
236+
return withProvidedContext(this.delegate.callTool(callToolRequest)).block();
237+
228238
}
229239

230240
/**
@@ -234,7 +244,7 @@ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolReque
234244
* pagination if more tools are available
235245
*/
236246
public McpSchema.ListToolsResult listTools() {
237-
return this.delegate.listTools().block();
247+
return withProvidedContext(this.delegate.listTools()).block();
238248
}
239249

240250
/**
@@ -245,7 +255,8 @@ public McpSchema.ListToolsResult listTools() {
245255
* pagination if more tools are available
246256
*/
247257
public McpSchema.ListToolsResult listTools(String cursor) {
248-
return this.delegate.listTools(cursor).block();
258+
return withProvidedContext(this.delegate.listTools(cursor)).block();
259+
249260
}
250261

251262
// --------------------------
@@ -257,7 +268,8 @@ public McpSchema.ListToolsResult listTools(String cursor) {
257268
* @return The list of all resources result
258269
*/
259270
public McpSchema.ListResourcesResult listResources() {
260-
return this.delegate.listResources().block();
271+
return withProvidedContext(this.delegate.listResources()).block();
272+
261273
}
262274

263275
/**
@@ -266,7 +278,8 @@ public McpSchema.ListResourcesResult listResources() {
266278
* @return The list of resources result
267279
*/
268280
public McpSchema.ListResourcesResult listResources(String cursor) {
269-
return this.delegate.listResources(cursor).block();
281+
return withProvidedContext(this.delegate.listResources(cursor)).block();
282+
270283
}
271284

272285
/**
@@ -275,7 +288,8 @@ public McpSchema.ListResourcesResult listResources(String cursor) {
275288
* @return the resource content.
276289
*/
277290
public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) {
278-
return this.delegate.readResource(resource).block();
291+
return withProvidedContext(this.delegate.readResource(resource)).block();
292+
279293
}
280294

281295
/**
@@ -284,15 +298,17 @@ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) {
284298
* @return the resource content.
285299
*/
286300
public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest readResourceRequest) {
287-
return this.delegate.readResource(readResourceRequest).block();
301+
return withProvidedContext(this.delegate.readResource(readResourceRequest)).block();
302+
288303
}
289304

290305
/**
291306
* Retrieves the list of all resource templates provided by the server.
292307
* @return The list of all resource templates result.
293308
*/
294309
public McpSchema.ListResourceTemplatesResult listResourceTemplates() {
295-
return this.delegate.listResourceTemplates().block();
310+
return withProvidedContext(this.delegate.listResourceTemplates()).block();
311+
296312
}
297313

298314
/**
@@ -304,7 +320,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() {
304320
* @return The list of resource templates result.
305321
*/
306322
public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor) {
307-
return this.delegate.listResourceTemplates(cursor).block();
323+
return withProvidedContext(this.delegate.listResourceTemplates(cursor)).block();
324+
308325
}
309326

310327
/**
@@ -317,7 +334,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor
317334
* subscribe to.
318335
*/
319336
public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
320-
this.delegate.subscribeResource(subscribeRequest).block();
337+
withProvidedContext(this.delegate.subscribeResource(subscribeRequest)).block();
338+
321339
}
322340

323341
/**
@@ -326,7 +344,8 @@ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
326344
* to unsubscribe from.
327345
*/
328346
public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
329-
this.delegate.unsubscribeResource(unsubscribeRequest).block();
347+
withProvidedContext(this.delegate.unsubscribeResource(unsubscribeRequest)).block();
348+
330349
}
331350

332351
// --------------------------
@@ -338,7 +357,7 @@ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest)
338357
* @return The list of all prompts result.
339358
*/
340359
public ListPromptsResult listPrompts() {
341-
return this.delegate.listPrompts().block();
360+
return withProvidedContext(this.delegate.listPrompts()).block();
342361
}
343362

344363
/**
@@ -347,19 +366,21 @@ public ListPromptsResult listPrompts() {
347366
* @return The list of prompts result.
348367
*/
349368
public ListPromptsResult listPrompts(String cursor) {
350-
return this.delegate.listPrompts(cursor).block();
369+
return withProvidedContext(this.delegate.listPrompts(cursor)).block();
370+
351371
}
352372

353373
public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) {
354-
return this.delegate.getPrompt(getPromptRequest).block();
374+
return withProvidedContext(this.delegate.getPrompt(getPromptRequest)).block();
355375
}
356376

357377
/**
358378
* Client can set the minimum logging level it wants to receive from the server.
359379
* @param loggingLevel the min logging level
360380
*/
361381
public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
362-
this.delegate.setLoggingLevel(loggingLevel).block();
382+
withProvidedContext(this.delegate.setLoggingLevel(loggingLevel)).block();
383+
363384
}
364385

365386
/**
@@ -369,7 +390,18 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
369390
* @return the completion result containing suggested values.
370391
*/
371392
public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) {
372-
return this.delegate.completeCompletion(completeRequest).block();
393+
return withProvidedContext(this.delegate.completeCompletion(completeRequest)).block();
394+
395+
}
396+
397+
/**
398+
* For a given action, on assembly, capture the "context" via the
399+
* {@link #contextProvider} and store it in the Reactor context.
400+
* @param action the action to perform
401+
* @return the result of the action
402+
*/
403+
private <T> Mono<T> withProvidedContext(Mono<T> action) {
404+
return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get()));
373405
}
374406

375407
}

0 commit comments

Comments
 (0)