serverContextExtractor = (ServerRequest r) -> {
+ var headerValue = r.headers().firstHeader(HEADER_NAME);
+ return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
+ : McpTransportContext.EMPTY;
+ };
+
+ // Server transports
+ private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+
+ private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider
+ .builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+
+ private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .messageEndpoint("/mcp/message")
+ .build();
+
+ // Async clients
+ private final McpAsyncClient asyncStreamableClient = McpClient
+ .async(WebClientStreamableHttpTransport
+ .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider))
+ .build())
+ .build();
+
+ private final McpAsyncClient asyncSseClient = McpClient
+ .async(WebFluxSseClientTransport
+ .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider))
+ .build())
+ .build();
+
+ private DisposableServer httpServer;
+
+ @AfterEach
+ public void after() {
+ if (statelessServerTransport != null) {
+ statelessServerTransport.closeGracefully().block();
+ }
+ if (streamableServerTransport != null) {
+ streamableServerTransport.closeGracefully().block();
+ }
+ if (sseServerTransport != null) {
+ sseServerTransport.closeGracefully().block();
+ }
+ if (asyncStreamableClient != null) {
+ asyncStreamableClient.closeGracefully().block();
+ }
+ if (asyncSseClient != null) {
+ asyncSseClient.closeGracefully().block();
+ }
+ stopHttpServer();
+ }
+
+ @Test
+ void asyncClientStatelessServer() {
+
+ startHttpServer(statelessServerTransport.getRouterFunction());
+
+ var mcpServer = McpServer.async(statelessServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler))
+ .build();
+
+ StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> {
+ assertThat(initResult).isNotNull();
+ }).verifyComplete();
+
+ // Test tool call with context
+ StepVerifier
+ .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()))
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY,
+ McpTransportContext.create(Map.of("client-side-header-value", "some important value")))))
+ .assertNext(response -> {
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ })
+ .verifyComplete();
+
+ mcpServer.close();
+ }
+
+ @Test
+ void asyncClientStreamableServer() {
+
+ startHttpServer(streamableServerTransport.getRouterFunction());
+
+ var mcpServer = McpServer.async(streamableServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler))
+ .build();
+
+ StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> {
+ assertThat(initResult).isNotNull();
+ }).verifyComplete();
+
+ // Test tool call with context
+ StepVerifier
+ .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()))
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY,
+ McpTransportContext.create(Map.of("client-side-header-value", "some important value")))))
+ .assertNext(response -> {
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ })
+ .verifyComplete();
+
+ mcpServer.close();
+ }
+
+ @Test
+ void asyncClientSseServer() {
+
+ startHttpServer(sseServerTransport.getRouterFunction());
+
+ var mcpServer = McpServer.async(sseServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler))
+ .build();
+
+ StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> {
+ assertThat(initResult).isNotNull();
+ }).verifyComplete();
+
+ // Test tool call with context
+ StepVerifier
+ .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()))
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY,
+ McpTransportContext.create(Map.of("client-side-header-value", "some important value")))))
+ .assertNext(response -> {
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ })
+ .verifyComplete();
+
+ mcpServer.close();
+ }
+
+ private void startHttpServer(RouterFunction> routerFunction) {
+
+ HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction);
+ ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);
+ this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
+ }
+
+ private void stopHttpServer() {
+ if (httpServer != null) {
+ httpServer.disposeNow();
+ }
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java
new file mode 100644
index 000000000..865192489
--- /dev/null
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java
@@ -0,0 +1,273 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ */
+
+package io.modelcontextprotocol.common;
+
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.function.Supplier;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.McpSyncClient;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.server.McpServerFeatures;
+import io.modelcontextprotocol.server.McpStatelessServerFeatures;
+import io.modelcontextprotocol.server.McpSyncServerExchange;
+import io.modelcontextprotocol.server.McpTransportContextExtractor;
+import io.modelcontextprotocol.server.TestUtil;
+import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
+import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport;
+import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import reactor.core.publisher.Mono;
+import reactor.netty.DisposableServer;
+import reactor.netty.http.server.HttpServer;
+
+import org.springframework.http.server.reactive.HttpHandler;
+import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.server.RouterFunction;
+import org.springframework.web.reactive.function.server.RouterFunctions;
+import org.springframework.web.reactive.function.server.ServerRequest;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for {@link McpTransportContext} propagation between MCP client and
+ * server using synchronous operations in a Spring WebFlux environment.
+ *
+ * This test class validates the end-to-end flow of transport context propagation across
+ * different WebFlux-based MCP transport implementations
+ *
+ *
+ * The test scenario follows these steps:
+ *
+ * - The client stores a value in a thread-local variable
+ * - The client's transport context provider reads this value and includes it in the MCP
+ * context
+ * - A WebClient filter extracts the context value and adds it as an HTTP header
+ * (x-test)
+ * - The server's {@link McpTransportContextExtractor} reads the header from the
+ * request
+ * - The server returns the header value as the tool call result, validating the
+ * round-trip
+ *
+ *
+ *
+ * This test demonstrates how custom context can be propagated through HTTP headers in a
+ * reactive WebFlux environment, enabling features like authentication tokens, correlation
+ * IDs, or other metadata to flow between MCP client and server.
+ *
+ * @author Daniel Garnier-Moiroux
+ * @author Christian Tzolov
+ * @since 1.0.0
+ * @see McpTransportContext
+ * @see McpTransportContextExtractor
+ * @see WebFluxStatelessServerTransport
+ * @see WebFluxStreamableServerTransportProvider
+ * @see WebFluxSseServerTransportProvider
+ */
+@Timeout(15)
+public class SyncServerMcpTransportContextIntegrationTests {
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>();
+
+ private static final String HEADER_NAME = "x-test";
+
+ private final Supplier clientContextProvider = () -> {
+ var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get();
+ return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue))
+ : McpTransportContext.EMPTY;
+ };
+
+ private final BiFunction statelessHandler = (
+ transportContext, request) -> {
+ return new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null);
+ };
+
+ private final BiFunction statefulHandler = (
+ exchange, request) -> statelessHandler.apply(exchange.transportContext(), request);
+
+ private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> {
+ var headerValue = r.headers().firstHeader(HEADER_NAME);
+ return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
+ : McpTransportContext.EMPTY;
+ };
+
+ private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+
+ private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider
+ .builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+
+ private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .messageEndpoint("/mcp/message")
+ .build();
+
+ private final McpSyncClient streamableClient = McpClient
+ .sync(WebClientStreamableHttpTransport.builder(WebClient.builder()
+ .baseUrl("http://localhost:" + PORT)
+ .filter((request, next) -> Mono.deferContextual(ctx -> {
+ var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY);
+ // // do stuff with the context
+ var headerValue = context.get("client-side-header-value");
+ if (headerValue == null) {
+ return next.exchange(request);
+ }
+ var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build();
+ return next.exchange(reqWithHeader);
+ }))).build())
+ .transportContextProvider(clientContextProvider)
+ .build();
+
+ private final McpSyncClient sseClient = McpClient.sync(WebFluxSseClientTransport.builder(WebClient.builder()
+ .baseUrl("http://localhost:" + PORT)
+ .filter((request, next) -> Mono.deferContextual(ctx -> {
+ var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY);
+ // // do stuff with the context
+ var headerValue = context.get("client-side-header-value");
+ if (headerValue == null) {
+ return next.exchange(request);
+ }
+ var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build();
+ return next.exchange(reqWithHeader);
+ }))).build()).transportContextProvider(clientContextProvider).build();
+
+ private final McpSchema.Tool tool = McpSchema.Tool.builder()
+ .name("test-tool")
+ .description("return the value of the x-test header from call tool request")
+ .build();
+
+ private DisposableServer httpServer;
+
+ @AfterEach
+ public void after() {
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.remove();
+ if (statelessServerTransport != null) {
+ statelessServerTransport.closeGracefully().block();
+ }
+ if (streamableServerTransport != null) {
+ streamableServerTransport.closeGracefully().block();
+ }
+ if (sseServerTransport != null) {
+ sseServerTransport.closeGracefully().block();
+ }
+ if (streamableClient != null) {
+ streamableClient.closeGracefully();
+ }
+ if (sseClient != null) {
+ sseClient.closeGracefully();
+ }
+ stopHttpServer();
+ }
+
+ @Test
+ void statelessServer() {
+
+ startHttpServer(statelessServerTransport.getRouterFunction());
+
+ var mcpServer = McpServer.sync(statelessServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler))
+ .build();
+
+ McpSchema.InitializeResult initResult = streamableClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
+ McpSchema.CallToolResult response = streamableClient
+ .callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+
+ mcpServer.close();
+ }
+
+ @Test
+ void streamableServer() {
+
+ startHttpServer(streamableServerTransport.getRouterFunction());
+
+ var mcpServer = McpServer.sync(streamableServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler))
+ .build();
+
+ McpSchema.InitializeResult initResult = streamableClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
+ McpSchema.CallToolResult response = streamableClient
+ .callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+
+ mcpServer.close();
+ }
+
+ @Test
+ void sseServer() {
+ startHttpServer(sseServerTransport.getRouterFunction());
+
+ var mcpServer = McpServer.sync(sseServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler))
+ .build();
+
+ McpSchema.InitializeResult initResult = sseClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
+ McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+
+ mcpServer.close();
+ }
+
+ private void startHttpServer(RouterFunction> routerFunction) {
+
+ HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction);
+ ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);
+ this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
+ }
+
+ private void stopHttpServer() {
+ if (httpServer != null) {
+ httpServer.disposeNow();
+ }
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml
index ea262d3a1..170309211 100644
--- a/mcp-spring/mcp-spring-webmvc/pom.xml
+++ b/mcp-spring/mcp-spring-webmvc/pom.xml
@@ -41,7 +41,7 @@
test
-
+
io.modelcontextprotocol.sdk
mcp-spring-webflux
0.12.0-SNAPSHOT
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java
new file mode 100644
index 000000000..1f5f1cc0c
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java
@@ -0,0 +1,306 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ */
+
+package io.modelcontextprotocol.common;
+
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.function.Supplier;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.McpSyncClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.server.McpServerFeatures;
+import io.modelcontextprotocol.server.McpStatelessServerFeatures;
+import io.modelcontextprotocol.server.McpStatelessSyncServer;
+import io.modelcontextprotocol.server.McpSyncServer;
+import io.modelcontextprotocol.server.McpSyncServerExchange;
+import io.modelcontextprotocol.server.McpTransportContextExtractor;
+import io.modelcontextprotocol.server.TestUtil;
+import io.modelcontextprotocol.server.TomcatTestUtil;
+import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer;
+import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
+import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport;
+import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerRequest;
+import org.springframework.web.servlet.function.ServerResponse;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for {@link McpTransportContext} propagation between MCP clients and
+ * servers using Spring WebMVC transport implementations.
+ *
+ *
+ * This test class validates the end-to-end flow of transport context propagation across
+ * different MCP transport mechanisms in a Spring WebMVC environment. It demonstrates how
+ * contextual information can be passed from client to server through HTTP headers and
+ * properly extracted and utilized on the server side.
+ *
+ *
Transport Types Tested
+ *
+ * - Stateless: Tests context propagation with
+ * {@link WebMvcStatelessServerTransport} where each request is independent
+ * - Streamable HTTP: Tests context propagation with
+ * {@link WebMvcStreamableServerTransportProvider} supporting stateful server
+ * sessions
+ * - Server-Sent Events (SSE): Tests context propagation with
+ * {@link WebMvcSseServerTransportProvider} for long-lived connections
+ *
+ *
+ * @author Daniel Garnier-Moiroux
+ * @author Christian Tzolov
+ */
+@Timeout(15)
+public class McpTransportContextIntegrationTests {
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private TomcatServer tomcatServer;
+
+ private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>();
+
+ private static final String HEADER_NAME = "x-test";
+
+ private final Supplier clientContextProvider = () -> {
+ var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get();
+ return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue))
+ : McpTransportContext.EMPTY;
+ };
+
+ private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body,
+ context) -> {
+ var headerValue = context.get("client-side-header-value");
+ if (headerValue != null) {
+ builder.header(HEADER_NAME, headerValue.toString());
+ }
+ };
+
+ private static final BiFunction statelessHandler = (
+ transportContext,
+ request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null);
+
+ private static final BiFunction statefulHandler = (
+ exchange, request) -> statelessHandler.apply(exchange.transportContext(), request);
+
+ private static McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> {
+ String headerValue = r.servletRequest().getHeader(HEADER_NAME);
+ return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
+ : McpTransportContext.EMPTY;
+ };
+
+ private final McpSyncClient streamableClient = McpClient
+ .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
+ .httpRequestCustomizer(clientRequestCustomizer)
+ .build())
+ .transportContextProvider(clientContextProvider)
+ .build();
+
+ private final McpSyncClient sseClient = McpClient
+ .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ .httpRequestCustomizer(clientRequestCustomizer)
+ .build())
+ .transportContextProvider(clientContextProvider)
+ .build();
+
+ private static final McpSchema.Tool tool = McpSchema.Tool.builder()
+ .name("test-tool")
+ .description("return the value of the x-test header from call tool request")
+ .build();
+
+ @AfterEach
+ public void after() {
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.remove();
+ if (streamableClient != null) {
+ streamableClient.closeGracefully();
+ }
+ if (sseClient != null) {
+ sseClient.closeGracefully();
+ }
+ stopTomcat();
+ }
+
+ @Test
+ void statelessServer() {
+ startTomcat(TestStatelessConfig.class);
+
+ McpSchema.InitializeResult initResult = streamableClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
+ McpSchema.CallToolResult response = streamableClient
+ .callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ }
+
+ @Test
+ void streamableServer() {
+
+ startTomcat(TestStreamableHttpConfig.class);
+
+ McpSchema.InitializeResult initResult = streamableClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
+ McpSchema.CallToolResult response = streamableClient
+ .callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ }
+
+ @Test
+ void sseServer() {
+ startTomcat(TestSseConfig.class);
+
+ McpSchema.InitializeResult initResult = sseClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
+ McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ }
+
+ private void startTomcat(Class> componentClass) {
+ tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, componentClass);
+ try {
+ tomcatServer.tomcat().start();
+ assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+ }
+
+ private void stopTomcat() {
+ if (tomcatServer != null && tomcatServer.tomcat() != null) {
+ try {
+ tomcatServer.tomcat().stop();
+ tomcatServer.tomcat().destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+ @Configuration
+ @EnableWebMvc
+ static class TestStatelessConfig {
+
+ @Bean
+ public WebMvcStatelessServerTransport webMvcStatelessServerTransport() {
+
+ return WebMvcStatelessServerTransport.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ @Bean
+ public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) {
+ return McpServer.sync(transportProvider)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler))
+ .build();
+ }
+
+ }
+
+ @Configuration
+ @EnableWebMvc
+ static class TestStreamableHttpConfig {
+
+ @Bean
+ public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport() {
+
+ return WebMvcStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(
+ WebMvcStreamableServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ @Bean
+ public McpSyncServer mcpStreamableServer(WebMvcStreamableServerTransportProvider transportProvider) {
+ return McpServer.sync(transportProvider)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler))
+ .build();
+ }
+
+ }
+
+ @Configuration
+ @EnableWebMvc
+ static class TestSseConfig {
+
+ @Bean
+ public WebMvcSseServerTransportProvider webMvcSseServerTransport() {
+
+ return WebMvcSseServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .messageEndpoint("/mcp/message")
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ @Bean
+ public McpSyncServer mcpSseServer(WebMvcSseServerTransportProvider transportProvider) {
+ return McpServer.sync(transportProvider)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler))
+ .build();
+
+ }
+
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java
new file mode 100644
index 000000000..fb19c62f7
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java
@@ -0,0 +1,284 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ */
+
+package io.modelcontextprotocol.common;
+
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.McpAsyncClient;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer;
+import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
+import io.modelcontextprotocol.server.McpAsyncServerExchange;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.server.McpServerFeatures;
+import io.modelcontextprotocol.server.McpStatelessServerFeatures;
+import io.modelcontextprotocol.server.McpTransportContextExtractor;
+import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
+import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
+import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
+import io.modelcontextprotocol.server.transport.TomcatTestUtil;
+import io.modelcontextprotocol.spec.McpSchema;
+import jakarta.servlet.Servlet;
+import jakarta.servlet.http.HttpServletRequest;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.apache.catalina.startup.Tomcat;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for {@link McpTransportContext} propagation between MCP clients and
+ * async servers.
+ *
+ *
+ * This test class validates the end-to-end flow of transport context propagation in MCP
+ * communication, demonstrating how contextual information can be passed from client to
+ * server through HTTP headers and accessed within server-side handlers.
+ *
+ *
Test Scenarios
+ *
+ * The tests cover multiple transport configurations with async servers:
+ *
+ * - Stateless server with async streamable HTTP clients
+ * - Streamable server with async streamable HTTP clients
+ * - SSE (Server-Sent Events) server with async SSE clients
+ *
+ *
+ * Context Propagation Flow
+ *
+ * - Client-side: Context data is stored in the Reactor Context and injected into HTTP
+ * headers via {@link McpSyncHttpClientRequestCustomizer}
+ * - Transport: The context travels as HTTP headers (specifically "x-test" header in
+ * these tests)
+ * - Server-side: A {@link McpTransportContextExtractor} extracts the header value and
+ * makes it available to request handlers through {@link McpTransportContext}
+ * - Verification: The server echoes back the received context value as the tool call
+ * result
+ *
+ *
+ *
+ * All tests use an embedded Tomcat server running on a dynamically allocated port to
+ * ensure isolation and prevent port conflicts during parallel test execution.
+ *
+ * @author Daniel Garnier-Moiroux
+ * @author Christian Tzolov
+ */
+@Timeout(15)
+public class AsyncServerMcpTransportContextIntegrationTests {
+
+ private static final int PORT = TomcatTestUtil.findAvailablePort();
+
+ private Tomcat tomcat;
+
+ private static final String HEADER_NAME = "x-test";
+
+ private final McpAsyncHttpClientRequestCustomizer asyncClientRequestCustomizer = (builder, method, endpoint, body,
+ context) -> {
+ var headerValue = context.get("client-side-header-value");
+ if (headerValue != null) {
+ builder.header(HEADER_NAME, headerValue.toString());
+ }
+ return Mono.just(builder);
+ };
+
+ private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> {
+ var headerValue = r.getHeader(HEADER_NAME);
+ return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
+ : McpTransportContext.EMPTY;
+ };
+
+ private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport
+ .builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+
+ private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider
+ .builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .build();
+
+ private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider
+ .builder()
+ .objectMapper(new ObjectMapper())
+ .contextExtractor(serverContextExtractor)
+ .messageEndpoint("/message")
+ .build();
+
+ private final McpAsyncClient asyncStreamableClient = McpClient
+ .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
+ .asyncHttpRequestCustomizer(asyncClientRequestCustomizer)
+ .build())
+ .build();
+
+ private final McpAsyncClient asyncSseClient = McpClient
+ .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ .asyncHttpRequestCustomizer(asyncClientRequestCustomizer)
+ .build())
+ .build();
+
+ private final McpSchema.Tool tool = McpSchema.Tool.builder()
+ .name("test-tool")
+ .description("return the value of the x-test header from call tool request")
+ .build();
+
+ private final BiFunction> asyncStatelessHandler = (
+ transportContext, request) -> {
+ return Mono
+ .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null));
+ };
+
+ private final BiFunction> asyncStatefulHandler = (
+ exchange, request) -> {
+ return asyncStatelessHandler.apply(exchange.transportContext(), request);
+ };
+
+ @AfterEach
+ public void after() {
+ if (statelessServerTransport != null) {
+ statelessServerTransport.closeGracefully().block();
+ }
+ if (streamableServerTransport != null) {
+ streamableServerTransport.closeGracefully().block();
+ }
+ if (sseServerTransport != null) {
+ sseServerTransport.closeGracefully().block();
+ }
+ if (asyncStreamableClient != null) {
+ asyncStreamableClient.closeGracefully().block();
+ }
+ if (asyncSseClient != null) {
+ asyncSseClient.closeGracefully().block();
+ }
+ stopTomcat();
+ }
+
+ @Test
+ void asyncClinetStatelessServer() {
+ startTomcat(statelessServerTransport);
+
+ var mcpServer = McpServer.async(statelessServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler))
+ .build();
+
+ StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> {
+ assertThat(initResult).isNotNull();
+ }).verifyComplete();
+
+ // Test tool call with context
+ StepVerifier
+ .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()))
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY,
+ McpTransportContext.create(Map.of("client-side-header-value", "some important value")))))
+ .assertNext(response -> {
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ })
+ .verifyComplete();
+
+ mcpServer.close();
+ }
+
+ @Test
+ void asyncClientStreamableServer() {
+ startTomcat(streamableServerTransport);
+
+ var mcpServer = McpServer.async(streamableServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler))
+ .build();
+
+ StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> {
+ assertThat(initResult).isNotNull();
+ }).verifyComplete();
+
+ // Test tool call with context
+ StepVerifier
+ .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()))
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY,
+ McpTransportContext.create(Map.of("client-side-header-value", "some important value")))))
+ .assertNext(response -> {
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ })
+ .verifyComplete();
+
+ mcpServer.close();
+ }
+
+ @Test
+ void asyncClientSseServer() {
+ startTomcat(sseServerTransport);
+
+ var mcpServer = McpServer.async(sseServerTransport)
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler))
+ .build();
+
+ StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> {
+ assertThat(initResult).isNotNull();
+ }).verifyComplete();
+
+ // Test tool call with context
+ StepVerifier
+ .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()))
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY,
+ McpTransportContext.create(Map.of("client-side-header-value", "some important value")))))
+ .assertNext(response -> {
+ assertThat(response).isNotNull();
+ assertThat(response.content()).hasSize(1)
+ .first()
+ .extracting(McpSchema.TextContent.class::cast)
+ .extracting(McpSchema.TextContent::text)
+ .isEqualTo("some important value");
+ })
+ .verifyComplete();
+
+ mcpServer.close();
+ }
+
+ private void startTomcat(Servlet transport) {
+ tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport);
+ try {
+ tomcat.start();
+ assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+ }
+
+ private void stopTomcat() {
+ if (tomcat != null) {
+ try {
+ tomcat.stop();
+ tomcat.destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java
similarity index 97%
rename from mcp/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java
rename to mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java
index 8d75b8479..42747f717 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java
@@ -48,7 +48,7 @@
* @author Daniel Garnier-Moiroux
*/
@Timeout(15)
-public class McpTransportContextIntegrationTests {
+public class SyncServerMcpTransportContextIntegrationTests {
private static final int PORT = TomcatTestUtil.findAvailablePort();
@@ -135,6 +135,12 @@ public void after() {
if (sseServerTransport != null) {
sseServerTransport.closeGracefully().block();
}
+ if (streamableClient != null) {
+ streamableClient.closeGracefully();
+ }
+ if (sseClient != null) {
+ sseClient.closeGracefully();
+ }
stopTomcat();
}