From f4be7ca6c40d5d9d74a3f78862864ffb55576628 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 8 Aug 2025 19:01:05 +0100 Subject: [PATCH 1/3] refactor: extract common integration test logic into abstract base classes - Move duplicated test methods from WebFlux and WebMvc integration test classes to abstract base classes - WebFluxSseIntegrationTests, WebFluxStreamableIntegrationTests now extend AbstractMcpClientServerIntegrationTests - WebFluxStatelessIntegrationTests, WebMvcStatelessIntegrationTests now extend AbstractStatelessIntegrationTests - Each concrete test class now only implements transport-specific setup methods (prepareClients, prepareAsyncServerBuilder, prepareSyncServerBuilder) - Eliminates ~1300+ lines of duplicated test code across multiple files - Improves maintainability by centralizing test logic in reusable base classes Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 1457 +--------------- .../WebFluxStatelessIntegrationTests.java | 455 +---- .../WebFluxStreamableIntegrationTests.java | 1479 +---------------- .../server/WebMvcSseIntegrationTests.java | 5 +- .../WebMvcStatelessIntegrationTests.java | 88 +- ...stractMcpClientServerIntegrationTests.java | 281 +++- ...stractMcpClientServerIntegrationTests.java | 277 ++- 7 files changed, 678 insertions(+), 3364 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 8ce714f94..1b82366f9 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,33 +4,10 @@ package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; @@ -40,36 +17,14 @@ import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; -import io.modelcontextprotocol.spec.McpSchema.CompleteResult; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptArgument; -import io.modelcontextprotocol.spec.McpSchema.PromptReference; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import net.javacrumbs.jsonunit.core.Option; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; -class WebFluxSseIntegrationTests { +class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -81,20 +36,8 @@ class WebFluxSseIntegrationTests { private WebFluxSseServerTransportProvider mcpServerTransportProvider; - ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); - - @BeforeEach - public void before() { - - this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + @Override + protected void prepareClients(int port, String mcpEndpoint) { clientBuilders.put("httpclient", McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) @@ -105,1389 +48,39 @@ public void before() { .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build())); - - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) - .thenReturn(mock(CallToolResult.class))) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build();) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - mcpServer.closeGracefully().block(); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(craeteMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(4)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - - mcpServer.closeGracefully().block(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .build(); - - return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(1)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - } - - mcpServer.closeGracefully().block(); + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithoutElicitationCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createElicitation(mock(ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutFail(String clientType) { - - var latch = new CountDownLatch(1); - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - try { - if (!latch.await(2, TimeUnit.SECONDS)) { - throw new RuntimeException("Timeout waiting for elicitation processing"); - } - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) // 1 second. - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - // Create client without roots capability - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsNotificationWithEmptyRootsList(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleHandlers(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testInitialize(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 3; - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - // Create and send notifications with different levels - - //@formatter:off - return exchange // This should be filtered out (DEBUG < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .then(exchange // This should be sent (NOTICE >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build())) - .then(exchange // This should be sent (ERROR > NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build())) - .then(exchange // This should be filtered out (INFO < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build())) - .then(exchange // This should be sent (ERROR >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); - //@formatter:on - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - } - mcpServer.close(); - } - - // --------------------------------------- - // Progress Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testProgressNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(McpSchema.Tool.builder() - .name("progress-test") - .description("Test progress notifications") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - - // Create and send notifications - var progressToken = (String) request.meta().get("progressToken"); - - return exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) - .then(// Send a progress notification with another progress value - // should - exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", - 0.0, 1.0, "Another processing started"))) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with progress notification handler - var mcpClient = clientBuilder.progressConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that sends progress notifications - McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() - .name("progress-test") - .meta(Map.of("progressToken", "test-progress-token")) - .build(); - CallToolResult result = mcpClient.callTool(callToolRequest); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.message(), n -> n)); - - // First notification should be 0.0/1.0 progress - assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); - - // Second notification should be 0.5/1.0 progress - assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); - assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); - - // Third notification should be another progress token with 0.0/1.0 progress - assertThat(notificationMap.get("Another processing started").progressToken()) - .isEqualTo("another-progress-token"); - assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Another processing started").message()) - .isEqualTo("Another processing started"); - - // Fourth notification should be 1.0/1.0 progress - assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); - } - finally { - mcpServer.close(); - } - } - - // --------------------------------------- - // Completion Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCompletionShouldReturnExpectedSuggestions(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - var expectedValues = List.of("python", "pytorch", "pyside"); - var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total - true // hasMore - )); - - AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (mcpSyncServerExchange, - request) -> { - samplingRequest.set(request); - return completionResponse; - }; - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().completions().build()) - .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "Code review", "this is code review prompt", - List.of(new PromptArgument("language", "Language", "string", false))), - (mcpSyncServerExchange, getPromptRequest) -> null)) - .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), - new CompleteRequest.CompleteArgument("language", "py")); - - CompleteResult result = mcpClient.completeCompletion(request); - - assertThat(result).isNotNull(); - - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Ping Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testPingSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationFailure(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputMissingStructuredContent(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); + @BeforeEach + public void before() { - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - mcpServer.close(); + prepareClients(PORT, null); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputRuntimeToolAddition(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Start server without tools - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java index 0327e6b53..302c58c5f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -4,51 +4,29 @@ package io.modelcontextprotocol; +import java.time.Duration; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; -import io.modelcontextprotocol.spec.McpSchema.CompleteResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptArgument; -import io.modelcontextprotocol.spec.McpSchema.PromptReference; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.server.McpTransportContext; -import net.javacrumbs.jsonunit.core.Option; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; - -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.awaitility.Awaitility.await; - -class WebFluxStatelessIntegrationTests { +class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -58,19 +36,8 @@ class WebFluxStatelessIntegrationTests { private WebFluxStatelessServerTransport mcpStreamableServerTransport; - ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); - - @BeforeEach - public void before() { - this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - + @Override + protected void prepareClients(int port, String mcpEndpoint) { clientBuilders .put("httpclient", McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) @@ -83,391 +50,37 @@ public void before() { .build()) .initializationTimeout(Duration.ofHours(10)) .requestTimeout(Duration.ofHours(10))); - - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( - new Tool("tool1", "tool1 description", emptyJsonSchema), (transportContext, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransport) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testInitialize(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var mcpServer = McpServer.sync(mcpStreamableServerTransport).build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Completion Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCompletionShouldReturnExpectedSuggestions(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - var expectedValues = List.of("python", "pytorch", "pyside"); - var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total - true // hasMore - )); - - AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (transportContext, - request) -> { - samplingRequest.set(request); - return completionResponse; - }; - - var mcpServer = McpServer.sync(mcpStreamableServerTransport) - .capabilities(ServerCapabilities.builder().completions().build()) - .prompts(new McpStatelessServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "Code review", "this is code review prompt", - List.of(new PromptArgument("language", "Language", "string", false))), - (transportContext, getPromptRequest) -> null)) - .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), - new CompleteRequest.CompleteArgument("language", "py")); - - CompleteResult result = mcpClient.completeCompletion(request); - - assertThat(result).isNotNull(); - - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); - } - - mcpServer.close(); } - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( - calculatorTool, (transportContext, request) -> { - String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransport) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpStreamableServerTransport); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationFailure(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( - calculatorTool, (transportContext, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransport) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpStreamableServerTransport); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputMissingStructuredContent(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( - calculatorTool, (transportContext, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransport) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .instructions("bla") - .tools(tool) + @BeforeEach + public void before() { + this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - mcpServer.close(); + prepareClients(PORT, null); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputRuntimeToolAddition(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Start server without tools - var mcpServer = McpServer.sync(mcpStreamableServerTransport) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( - dynamicTool, (transportContext, request) -> { - int count = (Integer) request.arguments().getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index 5cd19e627..c05570adf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -4,70 +4,27 @@ package io.modelcontextprotocol; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; -import io.modelcontextprotocol.spec.McpSchema.CompleteResult; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptArgument; -import io.modelcontextprotocol.spec.McpSchema.PromptReference; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import net.javacrumbs.jsonunit.core.Option; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - -class WebFluxStreamableIntegrationTests { +class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -77,7 +34,29 @@ class WebFluxStreamableIntegrationTests { private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; - ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build())); + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build())); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpStreamableServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpStreamableServerTransportProvider); + } @BeforeEach public void before() { @@ -92,19 +71,7 @@ public void before() { ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); - clientBuilders - .put("webflux", McpClient - .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build()) - .initializationTimeout(Duration.ofHours(10)) - .requestTimeout(Duration.ofHours(10))); - + prepareClients(PORT, null); } @AfterEach @@ -114,1380 +81,4 @@ public void after() { } } - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) - .thenReturn(mock(CallToolResult.class))) - .build(); - - var server = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build();) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = CreateMessageRequest.builder() - .messages(List - .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = CreateMessageRequest.builder() - .messages(List - .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(craeteMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .requestTimeout(Duration.ofSeconds(4)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = CreateMessageRequest.builder() - .messages(List - .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) - .build(); - - return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .requestTimeout(Duration.ofSeconds(1)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - } - - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithoutElicitationCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) - .then(Mono.just(mock(CallToolResult.class)))) - .build(); - - var server = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutFail(String clientType) { - - var latch = new CountDownLatch(1); - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - try { - if (!latch.await(2, TimeUnit.SECONDS)) { - throw new RuntimeException("Timeout waiting for elicitation processing"); - } - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - AtomicReference resultRef = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - return exchange.createElicitation(elicitationRequest) - .doOnNext(resultRef::set) - .then(Mono.just(callResponse)); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) // 1 second. - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - ElicitResult elicitResult = resultRef.get(); - assertThat(elicitResult).isNull(); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> { - }) - .tools(tool) - .build(); - - // Create client without roots capability - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsNotificationWithEmptyRootsList(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleHandlers(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testInitialize(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 3; - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("logging-test", "Test logging notifications", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - // Create and send notifications with different levels - - //@formatter:off - return exchange // This should be filtered out (DEBUG < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .then(exchange // This should be sent (NOTICE >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build())) - .then(exchange // This should be sent (ERROR > NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build())) - .then(exchange // This should be filtered out (INFO < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build())) - .then(exchange // This should be sent (ERROR >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); - //@formatter:on - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - } - mcpServer.close(); - } - - // --------------------------------------- - // Progress Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testProgressNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder() - .name("progress-test") - .description("Test progress notifications") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - - // Create and send notifications - var progressToken = (String) request.meta().get("progressToken"); - - return exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) - .then(// Send a progress notification with another progress value - // should - exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", - 0.0, 1.0, "Another processing started"))) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with progress notification handler - var mcpClient = clientBuilder.progressConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that sends progress notifications - McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() - .name("progress-test") - .meta(Map.of("progressToken", "test-progress-token")) - .build(); - CallToolResult result = mcpClient.callTool(callToolRequest); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.message(), n -> n)); - - // First notification should be 0.0/1.0 progress - assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); - - // Second notification should be 0.5/1.0 progress - assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); - assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); - - // Third notification should be another progress token with 0.0/1.0 progress - assertThat(notificationMap.get("Another processing started").progressToken()) - .isEqualTo("another-progress-token"); - assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Another processing started").message()) - .isEqualTo("Another processing started"); - - // Fourth notification should be 1.0/1.0 progress - assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); - } - finally { - mcpServer.close(); - } - } - - // --------------------------------------- - // Completion Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCompletionShouldReturnExpectedSuggestions(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - var expectedValues = List.of("python", "pytorch", "pyside"); - var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total - true // hasMore - )); - - AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (mcpSyncServerExchange, - request) -> { - samplingRequest.set(request); - return completionResponse; - }; - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .capabilities(ServerCapabilities.builder().completions().build()) - .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "Code review", "this is code review prompt", - List.of(new PromptArgument("language", "Language", "string", false))), - (mcpSyncServerExchange, getPromptRequest) -> null)) - .completions(new McpServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), - new CompleteRequest.CompleteArgument("language", "py")); - - CompleteResult result = mcpClient.completeCompletion(request); - - assertThat(result).isNotNull(); - - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Ping Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testPingSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }) - .build(); - - var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationFailure(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputMissingStructuredContent(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); - - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .instructions("bla") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputRuntimeToolAddition(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Start server without tools - var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); - } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; - } - } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 45f6b94f0..071ed51b7 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -55,7 +55,10 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java index b2264ea00..93735d942 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -11,8 +11,6 @@ import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -29,7 +27,6 @@ import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; -import io.modelcontextprotocol.spec.McpSchema; import reactor.core.scheduler.Schedulers; class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { @@ -63,6 +60,31 @@ public RouterFunction routerFunction(WebMvcStatelessServerTransp private TomcatTestUtil.TomcatServer tomcatServer; + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransport); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build())); + } + @BeforeEach public void before() { @@ -76,33 +98,13 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(MESSAGE_ENDPOINT) - .build())); + prepareClients(PORT, MESSAGE_ENDPOINT); // Get the transport from Spring context this.mcpServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); } - @Override - protected StatelessAsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(this.mcpServerTransport); - } - - @Override - protected StatelessSyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(this.mcpServerTransport); - } - @AfterEach public void after() { reactor.netty.http.HttpResources.disposeLoopsAndConnections(); @@ -124,42 +126,4 @@ public void after() { } } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void simple(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var server = McpServer.async(this.mcpServerTransport) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1000)) - .build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .requestTimeout(Duration.ofSeconds(1000)) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - } - server.closeGracefully(); - } - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - - clientBuilders.put("httpclient", McpClient - .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) - .initializationTimeout(Duration.ofHours(10)) - .requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + port)) - .endpoint(mcpEndpoint) - .build())); - } - } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index b3a699b94..ef6730a7c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -19,10 +19,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; +import java.util.stream.Collectors; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -31,16 +34,22 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; @@ -740,7 +749,6 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ { "$schema": "http://json-schema.org/draft-07/schema#", @@ -944,6 +952,276 @@ void testInitialize(String clientType) { mcpServer.close(); } + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + ; + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + mcpServer.close(); + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(new CallToolResult(("Progress test completed"), false)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testPingSuccess(String clientType) { @@ -1006,7 +1284,6 @@ void testPingSuccess(String clientType) { // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testStructuredOutputValidationSuccess(String clientType) { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java index a53501898..28b353d32 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -19,10 +19,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; +import java.util.stream.Collectors; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -32,6 +35,8 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; @@ -737,7 +742,6 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ { "$schema": "http://json-schema.org/draft-07/schema#", @@ -941,6 +945,276 @@ void testInitialize(String clientType) { mcpServer.close(); } + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + ; + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + mcpServer.close(); + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(new CallToolResult(("Progress test completed"), false)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new McpSchema.Prompt("code_review", "Code review", "this is code review prompt", + List.of(new McpSchema.PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient" }) void testPingSuccess(String clientType) { @@ -1003,7 +1277,6 @@ void testPingSuccess(String clientType) { // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient" }) void testStructuredOutputValidationSuccess(String clientType) { From 56c2bdfdc96b7c143167df91e353c7cd25f7034b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 8 Aug 2025 19:47:57 +0100 Subject: [PATCH 2/3] increase request timeout for integration tests --- .../WebFluxSseIntegrationTests.java | 16 +++++++++----- .../WebFluxStreamableIntegrationTests.java | 21 ++++++++++++------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 1b82366f9..a1f1a8947 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol; +import java.time.Duration; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.springframework.http.server.reactive.HttpHandler; @@ -39,15 +41,19 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests @Override protected void prepareClients(int port, String mcpEndpoint) { - clientBuilders.put("httpclient", - McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + clientBuilders.put("webflux", McpClient .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); + .build()) + .requestTimeout(Duration.ofHours(10))); + } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index c05570adf..616c6dcf8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol; +import java.time.Duration; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.springframework.http.server.reactive.HttpHandler; @@ -37,15 +39,18 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati @Override protected void prepareClients(int port, String mcpEndpoint) { - clientBuilders.put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build())); + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build())); + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .requestTimeout(Duration.ofHours(10))); } @Override From 5dab3b51f659c3f8660e361aee439daa1025736f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 8 Aug 2025 20:06:24 +0100 Subject: [PATCH 3/3] Replace HttpServletSseServerTransportProviderIntegrationTests with HttpServletSseIntegrationTests extending the AbstractMcpClientServerIntegrationTests --- .../server/WebMvcSseIntegrationTests.java | 4 +- .../WebMvcStatelessIntegrationTests.java | 11 +- .../WebMvcStreamableIntegrationTests.java | 34 +- .../HttpServletSseIntegrationTests.java | 93 ++ ...HttpServletStreamableIntegrationTests.java | 2 +- ...rverTransportProviderIntegrationTests.java | 1390 ----------------- 6 files changed, 108 insertions(+), 1426 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 071ed51b7..995cbd165 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -42,11 +42,11 @@ protected void prepareClients(int port, String mcpEndpoint) { clientBuilders.put("httpclient", McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build()) - .initializationTimeout(Duration.ofHours(10)) .requestTimeout(Duration.ofHours(10))); clientBuilders.put("webflux", McpClient - .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build())); + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build()) + .requestTimeout(Duration.ofHours(10))); } @Configuration diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java index 93735d942..802363d59 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -75,14 +75,15 @@ protected void prepareClients(int port, String mcpEndpoint) { clientBuilders.put("httpclient", McpClient .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) - .initializationTimeout(Duration.ofHours(10)) .requestTimeout(Duration.ofHours(10))); clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + port)) - .endpoint(mcpEndpoint) - .build())); + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); } @BeforeEach diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index f99b016ff..84862f27e 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -124,42 +124,20 @@ public void after() { } } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void simple(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var server = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1000)) - .build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .requestTimeout(Duration.ofSeconds(1000)) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - } - server.closeGracefully(); - } - @Override protected void prepareClients(int port, String mcpEndpoint) { clientBuilders.put("httpclient", McpClient .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) - .initializationTimeout(Duration.ofHours(10)) .requestTimeout(Duration.ofHours(10))); clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + port)) - .endpoint(mcpEndpoint) - .build())); + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java new file mode 100644 index 000000000..56e74218f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; + +class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 07c6e7c5c..6ac10014e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -55,7 +55,7 @@ public void before() { .put("httpclient", McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) .endpoint(MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + .build()).requestTimeout(Duration.ofHours(10))); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java deleted file mode 100644 index bf38e68ec..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ /dev/null @@ -1,1390 +0,0 @@ -/* - * Copyright 2024 - 2025 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import net.javacrumbs.jsonunit.core.Option; - -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import org.springframework.web.client.RestClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.InstanceOfAssertFactories.type; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; - -class HttpServletSseServerTransportProviderIntegrationTests { - - private static final int PORT = TomcatTestUtil.findAvailablePort(); - - private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; - - private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - - private HttpServletSseServerTransportProvider mcpServerTransportProvider; - - McpClient.SyncSpec clientBuilder; - - private Tomcat tomcat; - - @BeforeEach - public void before() { - // Create and configure the transport provider - mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build(); - - tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); - try { - tomcat.start(); - assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build()); - } - - @AfterEach - public void after() { - if (mcpServerTransportProvider != null) { - mcpServerTransportProvider.closeGracefully().block(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - // @Disabled - void testCreateMessageWithoutSamplingCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @Test - void testCreateMessageSuccess() { - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @Test - // @Disabled - void testCreateElicitationWithoutElicitationCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createElicitation(mock(ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @Test - void testCreateElicitationSuccess() { - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutSuccess() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutFail() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - - mcpServer.close(); - } - } - - @Test - void testRootsWithoutCapability() { - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @Test - void testRootsNotificationWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithMultipleHandlers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - assertThat(McpTestServletFilter.getThreadLocalValue()).as("blocking code exectuion should be offloaded") - .isNull(); - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @Test - void testToolCallImmediateExecution() { - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var threadLocalValue = McpTestServletFilter.getThreadLocalValue(); - return CallToolResult.builder() - .addTextContent(threadLocalValue != null ? threadLocalValue : "") - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .immediateExecution(true) - .build(); - - try (var mcpClient = clientBuilder.build()) { - mcpClient.initialize(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).first() - .asInstanceOf(type(McpSchema.TextContent.class)) - .extracting(McpSchema.TextContent::text) - .isEqualTo(McpTestServletFilter.THREAD_LOCAL_VALUE); - } - - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @Test - void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - @Test - void testLoggingNotification() { - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - // Create and send notifications with different levels - - // This should be filtered out (DEBUG < NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .block(); - - // This should be sent (NOTICE >= NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build()) - .block(); - - // This should be sent (ERROR > NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build()) - .block(); - - // This should be filtered out (INFO < NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build()) - .block(); - - // This should be sent (ERROR >= NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build()) - .block(); - - return Mono.just(new CallToolResult("Logging test completed", false)); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - - System.out.println("Received notifications: " + receivedNotifications); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); - } - mcpServer.close(); - } - - // --------------------------------------- - // Progress Tests - // --------------------------------------- - @Test - void testProgressNotification() { - // Create a list to store received progress notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - // Create server with a tool that sends progress notifications - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - McpSchema.Tool.builder() - .name("progress-test") - .description("Test progress notifications") - .inputSchema(emptyJsonSchema) - .build(), - null, (exchange, request) -> { - - var progressToken = request.progressToken(); - - exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.1, 1.0, "Test progress 1/10")) - .block(); - - exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Test progress 5/10")) - .block(); - - exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Test progress 10/10")) - .block(); - - return Mono.just(new CallToolResult("Progress test completed", false)); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - - // Create client with progress notification handler - try (var mcpClient = clientBuilder.progressConsumer(receivedNotifications::add).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that sends progress notifications - CallToolResult result = mcpClient.callTool( - new McpSchema.CallToolRequest("progress-test", Map.of(), Map.of("progressToken", "test-token"))); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); - - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - // Should have received 3 notifications - assertThat(receivedNotifications).hasSize(3); - - // Check the progress notifications - assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progressToken)) - .containsExactlyInAnyOrder("test-token", "test-token", "test-token"); - assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progress)) - .containsExactlyInAnyOrder(0.1, 0.5, 1.0); - }); - } - finally { - mcpServer.close(); - } - } - - // --------------------------------------- - // Ping Tests - // --------------------------------------- - @Test - void testPingSuccess() { - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema), - (exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @Test - void testStructuredOutputValidationSuccess() { - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - - // Verify structured content (may be null in sync server but validation still - // works) - if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) - .containsEntry("operation", "2 + 3") - .containsEntry("timestamp", "2024-01-01T10:00:00Z"); - } - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputValidationFailure() { - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputMissingStructuredContent() { - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputRuntimeToolAddition() { - // Start server without tools - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.structuredContent()).containsEntry("message", "Dynamic execution") - .containsEntry("count", 3); - } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; - } - -}