Skip to content

Commit ad7b6fc

Browse files
committed
Add end-to-end McpTransportContextIntegrationTests
Signed-off-by: Daniel Garnier-Moiroux <[email protected]>
1 parent dee0c1a commit ad7b6fc

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.common;
6+
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import io.modelcontextprotocol.client.McpClient;
9+
import io.modelcontextprotocol.client.McpClient.SyncSpec;
10+
import io.modelcontextprotocol.client.McpSyncClient;
11+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
12+
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
13+
import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
14+
import io.modelcontextprotocol.server.McpServer;
15+
import io.modelcontextprotocol.server.McpServerFeatures;
16+
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
17+
import io.modelcontextprotocol.server.McpSyncServerExchange;
18+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
19+
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
20+
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
21+
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
22+
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
23+
import io.modelcontextprotocol.spec.McpSchema;
24+
import jakarta.servlet.Servlet;
25+
import jakarta.servlet.http.HttpServletRequest;
26+
import java.util.Map;
27+
import java.util.function.BiFunction;
28+
import java.util.function.Supplier;
29+
import org.apache.catalina.LifecycleException;
30+
import org.apache.catalina.LifecycleState;
31+
import org.apache.catalina.startup.Tomcat;
32+
import org.junit.jupiter.api.AfterEach;
33+
import org.junit.jupiter.api.Test;
34+
import org.junit.jupiter.api.Timeout;
35+
36+
import static org.assertj.core.api.Assertions.assertThat;
37+
38+
/**
39+
* Test both Client and Server {@link McpTransportContext} integration, in two steps.
40+
* <p>
41+
* First, the client calls a tool and writes data stored in a thread-local to an HTTP
42+
* header using {@link SyncSpec#transportContextProvider(Supplier)} and
43+
* {@link McpSyncHttpClientRequestCustomizer}.
44+
* <p>
45+
* Then the server reads the header with a {@link McpTransportContextExtractor} and
46+
* returns the value as the result of the tool call.
47+
*
48+
* @author Daniel Garnier-Moiroux
49+
*/
50+
@Timeout(15)
51+
public class McpTransportContextIntegrationTests {
52+
53+
private static final int PORT = TomcatTestUtil.findAvailablePort();
54+
55+
private Tomcat tomcat;
56+
57+
private static final ThreadLocal<String> CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>();
58+
59+
private static final String HEADER_NAME = "x-test";
60+
61+
private final Supplier<McpTransportContext> clientContextProvider = () -> {
62+
var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get();
63+
return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue))
64+
: McpTransportContext.EMPTY;
65+
};
66+
67+
private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body,
68+
context) -> {
69+
var headerValue = context.get("client-side-header-value");
70+
if (headerValue != null) {
71+
builder.header(HEADER_NAME, headerValue.toString());
72+
}
73+
};
74+
75+
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = (HttpServletRequest r) -> {
76+
var headerValue = r.getHeader(HEADER_NAME);
77+
return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
78+
: McpTransportContext.EMPTY;
79+
};
80+
81+
private final BiFunction<McpTransportContext, McpSchema.CallToolRequest, McpSchema.CallToolResult> statelessHandler = (
82+
transportContext,
83+
request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null);
84+
85+
private final BiFunction<McpSyncServerExchange, McpSchema.CallToolRequest, McpSchema.CallToolResult> statefulHandler = (
86+
exchange, request) -> statelessHandler.apply(exchange.transportContext(), request);
87+
88+
private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport
89+
.builder()
90+
.objectMapper(new ObjectMapper())
91+
.contextExtractor(serverContextExtractor)
92+
.build();
93+
94+
private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider
95+
.builder()
96+
.objectMapper(new ObjectMapper())
97+
.contextExtractor(serverContextExtractor)
98+
.build();
99+
100+
private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider
101+
.builder()
102+
.objectMapper(new ObjectMapper())
103+
.contextExtractor(serverContextExtractor)
104+
.messageEndpoint("/message")
105+
.build();
106+
107+
private final McpSyncClient streamableClient = McpClient
108+
.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
109+
.httpRequestCustomizer(clientRequestCustomizer)
110+
.build())
111+
.transportContextProvider(clientContextProvider)
112+
.build();
113+
114+
private final McpSyncClient sseClient = McpClient
115+
.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
116+
.httpRequestCustomizer(clientRequestCustomizer)
117+
.build())
118+
.transportContextProvider(clientContextProvider)
119+
.build();
120+
121+
private final McpSchema.Tool tool = McpSchema.Tool.builder()
122+
.name("test-tool")
123+
.description("return the value of the x-test header from call tool request")
124+
.build();
125+
126+
@AfterEach
127+
public void after() {
128+
CLIENT_SIDE_HEADER_VALUE_HOLDER.remove();
129+
if (statelessServerTransport != null) {
130+
statelessServerTransport.closeGracefully().block();
131+
}
132+
if (streamableServerTransport != null) {
133+
streamableServerTransport.closeGracefully().block();
134+
}
135+
if (sseServerTransport != null) {
136+
sseServerTransport.closeGracefully().block();
137+
}
138+
stopTomcat();
139+
}
140+
141+
@Test
142+
void statelessServer() {
143+
startTomcat(statelessServerTransport);
144+
145+
var mcpServer = McpServer.sync(statelessServerTransport)
146+
.capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
147+
.tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler))
148+
.build();
149+
150+
McpSchema.InitializeResult initResult = streamableClient.initialize();
151+
assertThat(initResult).isNotNull();
152+
153+
CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
154+
McpSchema.CallToolResult response = streamableClient
155+
.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
156+
157+
assertThat(response).isNotNull();
158+
assertThat(response.content()).hasSize(1)
159+
.first()
160+
.extracting(McpSchema.TextContent.class::cast)
161+
.extracting(McpSchema.TextContent::text)
162+
.isEqualTo("some important value");
163+
164+
mcpServer.close();
165+
}
166+
167+
@Test
168+
void streamableServer() {
169+
startTomcat(streamableServerTransport);
170+
171+
var mcpServer = McpServer.sync(streamableServerTransport)
172+
.capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
173+
.tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler))
174+
.build();
175+
176+
McpSchema.InitializeResult initResult = streamableClient.initialize();
177+
assertThat(initResult).isNotNull();
178+
179+
CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
180+
McpSchema.CallToolResult response = streamableClient
181+
.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
182+
183+
assertThat(response).isNotNull();
184+
assertThat(response.content()).hasSize(1)
185+
.first()
186+
.extracting(McpSchema.TextContent.class::cast)
187+
.extracting(McpSchema.TextContent::text)
188+
.isEqualTo("some important value");
189+
190+
mcpServer.close();
191+
}
192+
193+
@Test
194+
void sseServer() {
195+
startTomcat(sseServerTransport);
196+
197+
var mcpServer = McpServer.sync(sseServerTransport)
198+
.capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
199+
.tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler))
200+
.build();
201+
202+
McpSchema.InitializeResult initResult = sseClient.initialize();
203+
assertThat(initResult).isNotNull();
204+
205+
CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value");
206+
McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
207+
208+
assertThat(response).isNotNull();
209+
assertThat(response.content()).hasSize(1)
210+
.first()
211+
.extracting(McpSchema.TextContent.class::cast)
212+
.extracting(McpSchema.TextContent::text)
213+
.isEqualTo("some important value");
214+
215+
mcpServer.close();
216+
}
217+
218+
private void startTomcat(Servlet transport) {
219+
tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport);
220+
try {
221+
tomcat.start();
222+
assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED);
223+
}
224+
catch (Exception e) {
225+
throw new RuntimeException("Failed to start Tomcat", e);
226+
}
227+
}
228+
229+
private void stopTomcat() {
230+
if (tomcat != null) {
231+
try {
232+
tomcat.stop();
233+
tomcat.destroy();
234+
}
235+
catch (LifecycleException e) {
236+
throw new RuntimeException("Failed to stop Tomcat", e);
237+
}
238+
}
239+
}
240+
241+
}

0 commit comments

Comments
 (0)