Skip to content

Commit 7a2c476

Browse files
committed
improve webmvc streamable http server tests
Signed-off-by: Christian Tzolov <[email protected]>
1 parent 8e01f41 commit 7a2c476

File tree

4 files changed

+146
-18
lines changed

4 files changed

+146
-18
lines changed

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ public RouterFunction<ServerResponse> getRouterFunction() {
205205
}
206206

207207
/**
208-
* Handles GET requests for SSE connections and message replay.
208+
* Setup the listening SSE connections and message replay.
209209
* @param request The incoming server request
210210
* @return A ServerResponse configured for SSE communication, or an error response
211211
*/
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
* @author Christian Tzolov
2929
*/
3030
@Timeout(15) // Giving extra time beyond the client timeout
31-
class WebMcpStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests {
31+
class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests {
3232

3333
private static final int PORT = TestUtil.findAvailablePort();
3434

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server;
6+
7+
import org.apache.catalina.Context;
8+
import org.apache.catalina.LifecycleException;
9+
import org.apache.catalina.startup.Tomcat;
10+
import org.junit.jupiter.api.Timeout;
11+
import org.springframework.context.annotation.Bean;
12+
import org.springframework.context.annotation.Configuration;
13+
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
14+
import org.springframework.web.servlet.DispatcherServlet;
15+
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
16+
import org.springframework.web.servlet.function.RouterFunction;
17+
import org.springframework.web.servlet.function.ServerResponse;
18+
19+
import com.fasterxml.jackson.databind.ObjectMapper;
20+
21+
import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
22+
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
23+
import reactor.netty.DisposableServer;
24+
25+
/**
26+
* Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}.
27+
*
28+
* @author Christian Tzolov
29+
*/
30+
@Timeout(15) // Giving extra time beyond the client timeout
31+
class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests {
32+
33+
private static final int PORT = TestUtil.findAvailablePort();
34+
35+
private static final String MCP_ENDPOINT = "/mcp";
36+
37+
private DisposableServer httpServer;
38+
39+
private AnnotationConfigWebApplicationContext appContext;
40+
41+
private Tomcat tomcat;
42+
43+
private McpStreamableServerTransportProvider transportProvider;
44+
45+
@Configuration
46+
@EnableWebMvc
47+
static class TestConfig {
48+
49+
@Bean
50+
public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() {
51+
return WebMvcStreamableServerTransportProvider.builder()
52+
.objectMapper(new ObjectMapper())
53+
.mcpEndpoint(MCP_ENDPOINT)
54+
.build();
55+
}
56+
57+
@Bean
58+
public RouterFunction<ServerResponse> routerFunction(
59+
WebMvcStreamableServerTransportProvider transportProvider) {
60+
return transportProvider.getRouterFunction();
61+
}
62+
63+
}
64+
65+
private McpStreamableServerTransportProvider createMcpTransportProvider() {
66+
// Set up Tomcat first
67+
tomcat = new Tomcat();
68+
tomcat.setPort(PORT);
69+
70+
// Set Tomcat base directory to java.io.tmpdir to avoid permission issues
71+
String baseDir = System.getProperty("java.io.tmpdir");
72+
tomcat.setBaseDir(baseDir);
73+
74+
// Use the same directory for document base
75+
Context context = tomcat.addContext("", baseDir);
76+
77+
// Create and configure Spring WebMvc context
78+
appContext = new AnnotationConfigWebApplicationContext();
79+
appContext.register(TestConfig.class);
80+
appContext.setServletContext(context.getServletContext());
81+
appContext.refresh();
82+
83+
// Get the transport from Spring context
84+
transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class);
85+
86+
// Create DispatcherServlet with our Spring context
87+
DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
88+
89+
// Add servlet to Tomcat and get the wrapper
90+
var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
91+
wrapper.setLoadOnStartup(1);
92+
context.addServletMappingDecoded("/*", "dispatcherServlet");
93+
94+
try {
95+
tomcat.start();
96+
tomcat.getConnector(); // Create and start the connector
97+
}
98+
catch (LifecycleException e) {
99+
throw new RuntimeException("Failed to start Tomcat", e);
100+
}
101+
102+
return transportProvider;
103+
}
104+
105+
@Override
106+
protected McpServer.SyncSpecification<?> prepareSyncServerBuilder() {
107+
return McpServer.sync(createMcpTransportProvider());
108+
}
109+
110+
@Override
111+
protected void onStart() {
112+
}
113+
114+
@Override
115+
protected void onClose() {
116+
if (httpServer != null) {
117+
httpServer.disposeNow();
118+
}
119+
}
120+
121+
}

mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Map;
2020
import java.util.concurrent.ConcurrentHashMap;
21+
import java.util.concurrent.CountDownLatch;
2122
import java.util.concurrent.TimeUnit;
2223
import java.util.concurrent.atomic.AtomicReference;
2324
import java.util.function.Function;
@@ -35,6 +36,8 @@
3536
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
3637
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
3738
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
39+
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
40+
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
3841
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
3942
import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
4043
import io.modelcontextprotocol.spec.McpSchema.Role;
@@ -489,51 +492,52 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) {
489492
@ValueSource(strings = { "httpclient", "webflux" })
490493
void testCreateElicitationWithRequestTimeoutFail(String clientType) {
491494

495+
var latch = new CountDownLatch(1);
496+
492497
var clientBuilder = clientBuilders.get(clientType);
493498

494-
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler = request -> {
499+
Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
495500
assertThat(request.message()).isNotEmpty();
496501
assertThat(request.requestedSchema()).isNotNull();
502+
497503
try {
498-
TimeUnit.SECONDS.sleep(2);
504+
if (!latch.await(2, TimeUnit.SECONDS)) {
505+
throw new RuntimeException("Timeout waiting for elicitation processing");
506+
}
499507
}
500508
catch (InterruptedException e) {
501509
throw new RuntimeException(e);
502510
}
503-
return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT,
504-
Map.of("message", request.message()));
511+
return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
505512
};
506513

507514
var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
508515
.capabilities(ClientCapabilities.builder().elicitation().build())
509516
.elicitation(elicitationHandler)
510517
.build();
511518

512-
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
513-
null);
519+
CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
520+
521+
AtomicReference<ElicitResult> resultRef = new AtomicReference<>();
514522

515523
McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
516524
.tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
517525
.callHandler((exchange, request) -> {
518526

519-
var elicitationRequest = McpSchema.ElicitRequest.builder()
527+
var elicitationRequest = ElicitRequest.builder()
520528
.message("Test message")
521529
.requestedSchema(
522530
Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
523531
.build();
524532

525-
StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
526-
assertThat(result).isNotNull();
527-
assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT);
528-
assertThat(result.content().get("message")).isEqualTo("Test message");
529-
}).verifyComplete();
530-
531-
return Mono.just(callResponse);
533+
return exchange.createElicitation(elicitationRequest)
534+
.doOnNext(resultRef::set)
535+
.then(Mono.just(callResponse));
532536
})
533537
.build();
534538

535539
var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
536-
.requestTimeout(Duration.ofSeconds(1))
540+
.requestTimeout(Duration.ofSeconds(1)) // 1 second.
537541
.tools(tool)
538542
.build();
539543

@@ -542,7 +546,10 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) {
542546

543547
assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
544548
mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
545-
}).withMessageContaining("Timeout");
549+
}).withMessageContaining("within 1000ms");
550+
551+
ElicitResult elicitResult = resultRef.get();
552+
assertThat(elicitResult).isNull();
546553

547554
mcpClient.closeGracefully();
548555
mcpServer.closeGracefully().block();

0 commit comments

Comments
 (0)