Skip to content

Commit 629464b

Browse files
authored
Add McpTransportContext to McpSyncClient (#522)
* McpSyncClient: introduce McpTransportContext - McpSyncClient should be considered thread-agnostic, and therefore consumers cannot rely on thread locals to propagate "context", e.g. pass down the Servlet request reference in a server context. - This PR introduces a mechanism for populating an McpTransportContext before executing client operations, and reworks the HTTP request customizers to leverage that McpTransportContext. - Move McpTransportContext from server to common package for shared client/server usage - Make McpTransportContext immutable by removing put() and copy() methods - Add static create() factory method for creating contexts with predefined data - Update McpTransportContextExtractor to return context instead of modifying one - Replace DefaultMcpTransportContext mutable implementation with immutable version - Update all transport implementations to use McpTransportContext.EMPTY as default - Rename *HttpRequestCustomizer -> *HttpClientRequestCustomizer - Add end-to-end McpTransportContextIntegrationTests - This PR introduces a breaking change to the Sync/Async request customizers. Signed-off-by: Daniel Garnier-Moiroux <[email protected]>
1 parent 95ba8e7 commit 629464b

File tree

52 files changed

+746
-327
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+746
-327
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
import com.fasterxml.jackson.core.type.TypeReference;
1313
import com.fasterxml.jackson.databind.ObjectMapper;
1414

15-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
16-
import io.modelcontextprotocol.server.McpTransportContext;
15+
import io.modelcontextprotocol.common.McpTransportContext;
1716
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1817
import io.modelcontextprotocol.spec.McpError;
1918
import io.modelcontextprotocol.spec.McpSchema;
@@ -201,7 +200,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
201200
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
202201
String sseEndpoint, Duration keepAliveInterval) {
203202
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
204-
(serverRequest, context) -> context);
203+
(serverRequest) -> McpTransportContext.EMPTY);
205204
}
206205

207206
/**
@@ -344,7 +343,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
344343
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
345344
}
346345

347-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
346+
McpTransportContext transportContext = this.contextExtractor.extract(request);
348347

349348
return ServerResponse.ok()
350349
.contentType(MediaType.TEXT_EVENT_STREAM)
@@ -401,7 +400,7 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
401400
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
402401
}
403402

404-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
403+
McpTransportContext transportContext = this.contextExtractor.extract(request);
405404

406405
return request.bodyToMono(String.class).flatMap(body -> {
407406
try {
@@ -491,7 +490,8 @@ public static class Builder {
491490

492491
private Duration keepAliveInterval;
493492

494-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
493+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
494+
serverRequest) -> McpTransportContext.EMPTY;
495495

496496
/**
497497
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
package io.modelcontextprotocol.server.transport;
66

77
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import io.modelcontextprotocol.common.McpTransportContext;
89
import io.modelcontextprotocol.server.McpStatelessServerHandler;
9-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.McpError;
1212
import io.modelcontextprotocol.spec.McpSchema;
1313
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
14-
import io.modelcontextprotocol.server.McpTransportContext;
1514
import io.modelcontextprotocol.util.Assert;
1615
import org.slf4j.Logger;
1716
import org.slf4j.LoggerFactory;
@@ -97,7 +96,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
9796
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
9897
}
9998

100-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
99+
McpTransportContext transportContext = this.contextExtractor.extract(request);
101100

102101
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
103102
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -151,7 +150,8 @@ public static class Builder {
151150

152151
private String mcpEndpoint = "/mcp";
153152

154-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
153+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
154+
serverRequest) -> McpTransportContext.EMPTY;
155155

156156
private Builder() {
157157
// used by a static method

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import com.fasterxml.jackson.core.type.TypeReference;
88
import com.fasterxml.jackson.databind.ObjectMapper;
9-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
9+
import io.modelcontextprotocol.common.McpTransportContext;
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.HttpHeaders;
1212
import io.modelcontextprotocol.spec.McpError;
@@ -15,7 +15,6 @@
1515
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
1616
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1717
import io.modelcontextprotocol.spec.ProtocolVersions;
18-
import io.modelcontextprotocol.server.McpTransportContext;
1918
import io.modelcontextprotocol.util.Assert;
2019
import io.modelcontextprotocol.util.KeepAliveScheduler;
2120

@@ -166,7 +165,7 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
166165
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
167166
}
168167

169-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
168+
McpTransportContext transportContext = this.contextExtractor.extract(request);
170169

171170
return Mono.defer(() -> {
172171
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -221,7 +220,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
221220
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
222221
}
223222

224-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
223+
McpTransportContext transportContext = this.contextExtractor.extract(request);
225224

226225
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
227226
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -309,7 +308,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
309308
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
310309
}
311310

312-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
311+
McpTransportContext transportContext = this.contextExtractor.extract(request);
313312

314313
return Mono.defer(() -> {
315314
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
@@ -402,7 +401,8 @@ public static class Builder {
402401

403402
private String mcpEndpoint = "/mcp";
404403

405-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
404+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
405+
serverRequest) -> McpTransportContext.EMPTY;
406406

407407
private boolean disallowDelete;
408408

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package io.modelcontextprotocol;
66

77
import java.time.Duration;
8+
import java.util.Map;
89

910
import org.junit.jupiter.api.AfterEach;
1011
import org.junit.jupiter.api.BeforeEach;
@@ -20,6 +21,7 @@
2021
import io.modelcontextprotocol.client.McpClient;
2122
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
2223
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
24+
import io.modelcontextprotocol.common.McpTransportContext;
2325
import io.modelcontextprotocol.server.McpServer;
2426
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2527
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
@@ -42,10 +44,8 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
4244

4345
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
4446

45-
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
46-
tc.put("important", "value");
47-
return tc;
48-
};
47+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
48+
.create(Map.of("important", "value"));
4949

5050
@Override
5151
protected void prepareClients(int port, String mcpEndpoint) {

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package io.modelcontextprotocol;
66

77
import java.time.Duration;
8+
import java.util.Map;
89

910
import org.junit.jupiter.api.AfterEach;
1011
import org.junit.jupiter.api.BeforeEach;
@@ -20,6 +21,7 @@
2021
import io.modelcontextprotocol.client.McpClient;
2122
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
2223
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
24+
import io.modelcontextprotocol.common.McpTransportContext;
2325
import io.modelcontextprotocol.server.McpServer;
2426
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2527
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
@@ -40,10 +42,8 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati
4042

4143
private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider;
4244

43-
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44-
tc.put("important", "value");
45-
return tc;
46-
};
45+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
46+
.create(Map.of("important", "value"));
4747

4848
@Override
4949
protected void prepareClients(int port, String mcpEndpoint) {

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
import com.fasterxml.jackson.core.type.TypeReference;
1515
import com.fasterxml.jackson.databind.ObjectMapper;
1616

17-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
18-
import io.modelcontextprotocol.server.McpTransportContext;
17+
import io.modelcontextprotocol.common.McpTransportContext;
1918
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2019
import io.modelcontextprotocol.spec.McpError;
2120
import io.modelcontextprotocol.spec.McpSchema;
@@ -192,7 +191,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
192191
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
193192
String sseEndpoint, Duration keepAliveInterval) {
194193
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
195-
(serverRequest, context) -> context);
194+
(serverRequest) -> McpTransportContext.EMPTY);
196195
}
197196

198197
/**
@@ -397,8 +396,7 @@ private ServerResponse handleMessage(ServerRequest request) {
397396
}
398397

399398
try {
400-
final McpTransportContext transportContext = this.contextExtractor.extract(request,
401-
new DefaultMcpTransportContext());
399+
final McpTransportContext transportContext = this.contextExtractor.extract(request);
402400

403401
String body = request.body(String.class);
404402
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
@@ -553,7 +551,8 @@ public static class Builder {
553551

554552
private Duration keepAliveInterval;
555553

556-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
554+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
555+
serverRequest) -> McpTransportContext.EMPTY;
557556

558557
/**
559558
* Sets the JSON object mapper to use for message serialization/deserialization.

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
package io.modelcontextprotocol.server.transport;
66

77
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import io.modelcontextprotocol.common.McpTransportContext;
89
import io.modelcontextprotocol.server.McpStatelessServerHandler;
9-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.McpError;
1212
import io.modelcontextprotocol.spec.McpSchema;
1313
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
14-
import io.modelcontextprotocol.server.McpTransportContext;
1514
import io.modelcontextprotocol.util.Assert;
1615
import org.slf4j.Logger;
1716
import org.slf4j.LoggerFactory;
@@ -101,7 +100,7 @@ private ServerResponse handlePost(ServerRequest request) {
101100
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
102101
}
103102

104-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
103+
McpTransportContext transportContext = this.contextExtractor.extract(request);
105104

106105
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
107106
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -176,7 +175,8 @@ public static class Builder {
176175

177176
private String mcpEndpoint = "/mcp";
178177

179-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
178+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
179+
serverRequest) -> McpTransportContext.EMPTY;
180180

181181
private Builder() {
182182
// used by a static method

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
import com.fasterxml.jackson.core.type.TypeReference;
2424
import com.fasterxml.jackson.databind.ObjectMapper;
2525

26-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
27-
import io.modelcontextprotocol.server.McpTransportContext;
26+
import io.modelcontextprotocol.common.McpTransportContext;
2827
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2928
import io.modelcontextprotocol.spec.HttpHeaders;
3029
import io.modelcontextprotocol.spec.McpError;
@@ -238,7 +237,7 @@ private ServerResponse handleGet(ServerRequest request) {
238237
return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM");
239238
}
240239

241-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
240+
McpTransportContext transportContext = this.contextExtractor.extract(request);
242241

243242
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
244243
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
@@ -322,7 +321,7 @@ private ServerResponse handlePost(ServerRequest request) {
322321
.body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON"));
323322
}
324323

325-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
324+
McpTransportContext transportContext = this.contextExtractor.extract(request);
326325

327326
try {
328327
String body = request.body(String.class);
@@ -431,7 +430,7 @@ private ServerResponse handleDelete(ServerRequest request) {
431430
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
432431
}
433432

434-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
433+
McpTransportContext transportContext = this.contextExtractor.extract(request);
435434

436435
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
437436
return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
@@ -604,7 +603,8 @@ public static class Builder {
604603

605604
private boolean disallowDelete = false;
606605

607-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
606+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
607+
serverRequest) -> McpTransportContext.EMPTY;
608608

609609
private Duration keepAliveInterval;
610610

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import static org.assertj.core.api.Assertions.assertThat;
77

88
import java.time.Duration;
9+
import java.util.Map;
910

1011
import org.apache.catalina.LifecycleException;
1112
import org.apache.catalina.LifecycleState;
@@ -26,6 +27,7 @@
2627
import io.modelcontextprotocol.client.McpClient;
2728
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
2829
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
30+
import io.modelcontextprotocol.common.McpTransportContext;
2931
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
3032
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
3133
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
@@ -40,10 +42,8 @@ class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
4042

4143
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
4244

43-
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44-
tc.put("important", "value");
45-
return tc;
46-
};
45+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext
46+
.create(Map.of("important", "value"));
4747

4848
@Override
4949
protected void prepareClients(int port, String mcpEndpoint) {

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import static org.assertj.core.api.Assertions.assertThat;
77

88
import java.time.Duration;
9+
import java.util.Map;
910

1011
import org.apache.catalina.LifecycleException;
1112
import org.apache.catalina.LifecycleState;
@@ -26,6 +27,7 @@
2627
import io.modelcontextprotocol.client.McpClient;
2728
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
2829
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
30+
import io.modelcontextprotocol.common.McpTransportContext;
2931
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
3032
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
3133
import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
@@ -40,10 +42,8 @@ class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegratio
4042

4143
private WebMvcStreamableServerTransportProvider mcpServerTransportProvider;
4244

43-
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
44-
tc.put("important", "value");
45-
return tc;
46-
};
45+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext
46+
.create(Map.of("important", "value"));
4747

4848
@Configuration
4949
@EnableWebMvc

0 commit comments

Comments
 (0)