diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java index d54a5bd43..cd6e8950f 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -46,6 +46,7 @@ import io.modelcontextprotocol.util.Utils; import net.javacrumbs.jsonunit.core.Option; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -70,7 +71,7 @@ public abstract class AbstractMcpClientServerIntegrationTests { abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -78,7 +79,6 @@ void simple(String clientType) { var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .requestTimeout(Duration.ofSeconds(1000)) .build(); - try ( // Create client without sampling capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -97,7 +97,7 @@ void simple(String clientType) { // Sampling Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -133,7 +133,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -202,7 +202,7 @@ void testCreateMessageSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { // Client @@ -282,7 +282,7 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { var clientBuilder = clientBuilders.get(clientType); @@ -348,7 +348,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt // Elicitation Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithoutElicitationCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -380,7 +380,7 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -437,7 +437,7 @@ void testCreateElicitationSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -498,7 +498,7 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutFail(String clientType) { var latch = new CountDownLatch(1); @@ -569,7 +569,7 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { // Roots Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -617,7 +617,7 @@ void testRootsSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -656,7 +656,7 @@ void testRootsWithoutCapability(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -686,7 +686,7 @@ void testRootsNotificationWithEmptyRootsList(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -720,7 +720,7 @@ void testRootsWithMultipleHandlers(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testRootsServerCloseWithActiveSubscription(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -755,7 +755,7 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { // Tools Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -806,7 +806,7 @@ void testToolCallSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -844,7 +844,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testToolCallSuccessWithTranportContextExtraction(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -901,7 +901,7 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -994,7 +994,7 @@ void testToolListChangeHandlingSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1015,7 +1015,7 @@ void testInitialize(String clientType) { // Logging Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testLoggingNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 3; CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); @@ -1128,7 +1128,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { // Progress Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress // token @@ -1234,7 +1234,7 @@ void testProgressNotification(String clientType) throws InterruptedException { // Completion Tests // --------------------------------------- @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testCompletionShouldReturnExpectedSuggestions(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1256,7 +1256,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { List.of(new PromptArgument("language", "Language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) .build(); try (var mcpClient = clientBuilder.build()) { @@ -1285,7 +1285,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { // Ping Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testPingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1348,7 +1348,7 @@ void testPingSuccess(String clientType) { // Tool Structured Output Schema Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1593,7 +1593,7 @@ void testStructuredOutputValidationFailure(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1644,7 +1644,7 @@ void testStructuredOutputMissingStructuredContent(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java index fd05b593b..d2b9d14d0 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.Map; +import java.util.stream.Stream; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; @@ -21,6 +22,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; import static org.assertj.core.api.Assertions.assertThat; @@ -37,6 +39,10 @@ class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationT private Tomcat tomcat; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + @BeforeEach public void before() { // Create and configure the transport provider diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 223c78a94..81423e0c5 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.Map; +import java.util.stream.Stream; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; @@ -21,6 +22,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; import static org.assertj.core.api.Assertions.assertThat; @@ -35,6 +37,10 @@ class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerInteg private Tomcat tomcat; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + @BeforeEach public void before() { // Create and configure the transport provider diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index f580b59e8..eb8abb90c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -6,10 +6,13 @@ import java.time.Duration; import java.util.Map; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; @@ -45,6 +48,10 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext .create(Map.of("important", "value")); + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java index a00e24b55..96a786a9e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -5,10 +5,13 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; @@ -35,6 +38,10 @@ class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests private WebFluxStatelessServerTransport mcpStreamableServerTransport; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { clientBuilders diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index e4bcef829..5ab651931 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -6,10 +6,13 @@ import java.time.Duration; import java.util.Map; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; @@ -43,6 +46,10 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext .create(Map.of("important", "value")); + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 3dacb62d8..1150e47f5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -31,7 +31,7 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.function.client.WebClient; -import static io.modelcontextprotocol.utils.McpJsonMapperUtils.JSON_MAPPER; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java index ae1f4f4d1..36aaa27fb 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java @@ -8,6 +8,7 @@ import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.Timeout; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; @@ -21,7 +22,7 @@ import reactor.netty.DisposableServer; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. * * @author Christian Tzolov */ diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java index c8c24b8a7..2f75551eb 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java @@ -21,7 +21,7 @@ import reactor.netty.DisposableServer; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * Tests for {@link McpAsyncServer} using {@link WebMvcSseServerTransportProvider}. * * @author Christian Tzolov */ diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index e780b8e51..045f9b3dd 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -7,12 +7,15 @@ import java.time.Duration; import java.util.Map; +import java.util.stream.Stream; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -43,6 +46,10 @@ class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext .create(Map.of("important", "value")); + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Override protected void prepareClients(int port, String mcpEndpoint) { diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java index 9633dfbd1..8c7b0a85e 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -6,12 +6,15 @@ import static org.assertj.core.api.Assertions.assertThat; import java.time.Duration; +import java.util.stream.Stream; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -37,6 +40,10 @@ class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests private WebMvcStatelessServerTransport mcpServerTransport; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Configuration @EnableWebMvc static class TestConfig { diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index abdd82967..cb7b4a2a0 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -7,12 +7,15 @@ import java.time.Duration; import java.util.Map; +import java.util.stream.Stream; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; @@ -43,6 +46,10 @@ class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegratio static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = r -> McpTransportContext .create(Map.of("important", "value")); + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); + } + @Configuration @EnableWebMvc static class TestConfig { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index a36d9006a..84bd271a5 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -50,11 +50,12 @@ import io.modelcontextprotocol.util.Utils; import net.javacrumbs.jsonunit.core.Option; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static io.modelcontextprotocol.utils.ToolsUtils.EMPTY_JSON_SCHEMA; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; @@ -74,7 +75,7 @@ public abstract class AbstractMcpClientServerIntegrationTests { abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -100,7 +101,7 @@ void simple(String clientType) { // Sampling Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -136,7 +137,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -205,7 +206,7 @@ void testCreateMessageSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { // Client @@ -285,7 +286,7 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { var clientBuilder = clientBuilders.get(clientType); @@ -351,7 +352,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt // Elicitation Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithoutElicitationCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -383,7 +384,7 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -440,7 +441,7 @@ void testCreateElicitationSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -501,7 +502,7 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutFail(String clientType) { var latch = new CountDownLatch(1); @@ -572,7 +573,7 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { // Roots Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -620,7 +621,7 @@ void testRootsSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -659,7 +660,7 @@ void testRootsWithoutCapability(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -689,7 +690,7 @@ void testRootsNotificationWithEmptyRootsList(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -723,7 +724,7 @@ void testRootsWithMultipleHandlers(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsServerCloseWithActiveSubscription(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -758,7 +759,7 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { // Tools Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -809,7 +810,7 @@ void testToolCallSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -847,7 +848,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccessWithTranportContextExtraction(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -904,7 +905,7 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -997,7 +998,7 @@ void testToolListChangeHandlingSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1018,7 +1019,7 @@ void testInitialize(String clientType) { // Logging Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testLoggingNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 3; CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); @@ -1131,7 +1132,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { // Progress Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress // token @@ -1237,7 +1238,7 @@ void testProgressNotification(String clientType) throws InterruptedException { // Completion Tests // --------------------------------------- @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCompletionShouldReturnExpectedSuggestions(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1288,7 +1289,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { // Ping Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testPingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1351,7 +1352,7 @@ void testPingSuccess(String clientType) { // Tool Structured Output Schema Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1596,7 +1597,7 @@ void testStructuredOutputValidationFailure(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1647,7 +1648,7 @@ void testStructuredOutputMissingStructuredContent(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java index 705535e93..240732ebe 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -28,10 +28,11 @@ import io.modelcontextprotocol.spec.McpSchema.Tool; import net.javacrumbs.jsonunit.core.Option; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; -import static io.modelcontextprotocol.utils.ToolsUtils.EMPTY_JSON_SCHEMA; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; @@ -49,7 +50,7 @@ public abstract class AbstractStatelessIntegrationTests { abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -76,7 +77,7 @@ void simple(String clientType) { // Tools Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -126,7 +127,7 @@ void testToolCallSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -164,7 +165,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -246,7 +247,7 @@ void testToolListChangeHandlingSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -267,7 +268,7 @@ void testInitialize(String clientType) { // Tool Structured Output Schema Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -342,7 +343,7 @@ void testStructuredOutputValidationSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -402,7 +403,7 @@ void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputWithInHandlerError(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -460,7 +461,7 @@ void testStructuredOutputWithInHandlerError(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationFailure(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -516,7 +517,7 @@ void testStructuredOutputValidationFailure(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -567,7 +568,7 @@ void testStructuredOutputMissingStructuredContent(String clientType) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 8a0b3e0d9..e1b051204 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.client; -import static io.modelcontextprotocol.utils.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -23,7 +22,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.json.McpJsonMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -54,6 +52,8 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Test suite for the {@link McpAsyncClient} that can be used with different * {@link McpTransport} implementations. diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index b0701911a..ed7f2c3ce 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -26,7 +26,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static io.modelcontextprotocol.utils.ToolsUtils.EMPTY_JSON_SCHEMA; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index d804de43b..d7b1dab2a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import static io.modelcontextprotocol.utils.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -15,6 +14,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java similarity index 84% rename from mcp-test/src/main/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java rename to mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java index e9ec8900c..723965519 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/utils/McpJsonMapperUtils.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -1,4 +1,4 @@ -package io.modelcontextprotocol.utils; +package io.modelcontextprotocol.util; import io.modelcontextprotocol.json.McpJsonMapper; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/utils/ToolsUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java similarity index 88% rename from mcp-test/src/main/java/io/modelcontextprotocol/utils/ToolsUtils.java rename to mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java index ec603aac1..ce8755223 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/utils/ToolsUtils.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -1,4 +1,4 @@ -package io.modelcontextprotocol.utils; +package io.modelcontextprotocol.util; import io.modelcontextprotocol.spec.McpSchema;