Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.modelcontextprotocol.util.Utils;
import net.javacrumbs.jsonunit.core.Option;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
Expand All @@ -70,15 +71,14 @@ public abstract class AbstractMcpClientServerIntegrationTests {
abstract protected McpServer.SyncSpecification<?> prepareSyncServerBuilder();

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void simple(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
.requestTimeout(Duration.ofSeconds(1000))
.build();

try (
// Create client without sampling capabilities
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
Expand All @@ -97,7 +97,7 @@ void simple(String clientType) {
// Sampling Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateMessageWithoutSamplingCapabilities(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -133,7 +133,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateMessageSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -202,7 +202,7 @@ void testCreateMessageSuccess(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException {

// Client
Expand Down Expand Up @@ -282,7 +282,7 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -348,7 +348,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
// Elicitation Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateElicitationWithoutElicitationCapabilities(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -380,7 +380,7 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateElicitationSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -437,7 +437,7 @@ void testCreateElicitationSuccess(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateElicitationWithRequestTimeoutSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -498,7 +498,7 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCreateElicitationWithRequestTimeoutFail(String clientType) {

var latch = new CountDownLatch(1);
Expand Down Expand Up @@ -569,7 +569,7 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) {
// Roots Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testRootsSuccess(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

Expand Down Expand Up @@ -617,7 +617,7 @@ void testRootsSuccess(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testRootsWithoutCapability(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -656,7 +656,7 @@ void testRootsWithoutCapability(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testRootsNotificationWithEmptyRootsList(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -686,7 +686,7 @@ void testRootsNotificationWithEmptyRootsList(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testRootsWithMultipleHandlers(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -720,7 +720,7 @@ void testRootsWithMultipleHandlers(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testRootsServerCloseWithActiveSubscription(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -755,7 +755,7 @@ void testRootsServerCloseWithActiveSubscription(String clientType) {
// Tools Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testToolCallSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -806,7 +806,7 @@ void testToolCallSuccess(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -844,7 +844,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testToolCallSuccessWithTranportContextExtraction(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -901,7 +901,7 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testToolListChangeHandlingSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -994,7 +994,7 @@ void testToolListChangeHandlingSuccess(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testInitialize(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand All @@ -1015,7 +1015,7 @@ void testInitialize(String clientType) {
// Logging Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testLoggingNotification(String clientType) throws InterruptedException {
int expectedNotificationsCount = 3;
CountDownLatch latch = new CountDownLatch(expectedNotificationsCount);
Expand Down Expand Up @@ -1128,7 +1128,7 @@ void testLoggingNotification(String clientType) throws InterruptedException {
// Progress Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testProgressNotification(String clientType) throws InterruptedException {
int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress
// token
Expand Down Expand Up @@ -1234,7 +1234,7 @@ void testProgressNotification(String clientType) throws InterruptedException {
// Completion Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : Completion call")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testCompletionShouldReturnExpectedSuggestions(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

Expand All @@ -1256,7 +1256,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
List.of(new PromptArgument("language", "Language", "string", false))),
(mcpSyncServerExchange, getPromptRequest) -> null))
.completions(new McpServerFeatures.SyncCompletionSpecification(
new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler))
new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler))
.build();

try (var mcpClient = clientBuilder.build()) {
Expand Down Expand Up @@ -1285,7 +1285,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
// Ping Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testPingSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -1348,7 +1348,7 @@ void testPingSuccess(String clientType) {
// Tool Structured Output Schema Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testStructuredOutputValidationSuccess(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

Expand Down Expand Up @@ -1593,7 +1593,7 @@ void testStructuredOutputValidationFailure(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testStructuredOutputMissingStructuredContent(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down Expand Up @@ -1644,7 +1644,7 @@ void testStructuredOutputMissingStructuredContent(String clientType) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
@MethodSource("clientsForTesting")
void testStructuredOutputRuntimeToolAddition(String clientType) {

var clientBuilder = clientBuilders.get(clientType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.time.Duration;
import java.util.Map;
import java.util.stream.Stream;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
Expand All @@ -21,6 +22,7 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.provider.Arguments;

import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -37,6 +39,10 @@ class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationT

private Tomcat tomcat;

static Stream<Arguments> clientsForTesting() {
return Stream.of(Arguments.of("httpclient"));
}

@BeforeEach
public void before() {
// Create and configure the transport provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.time.Duration;
import java.util.Map;
import java.util.stream.Stream;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
Expand All @@ -21,6 +22,7 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.provider.Arguments;

import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -35,6 +37,10 @@ class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerInteg

private Tomcat tomcat;

static Stream<Arguments> clientsForTesting() {
return Stream.of(Arguments.of("httpclient"));
}

@BeforeEach
public void before() {
// Create and configure the transport provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import java.time.Duration;
import java.util.Map;
import java.util.stream.Stream;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.provider.Arguments;

import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down Expand Up @@ -45,6 +48,10 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
.create(Map.of("important", "value"));

static Stream<Arguments> clientsForTesting() {
return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux"));
}

@Override
protected void prepareClients(int port, String mcpEndpoint) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
package io.modelcontextprotocol;

import java.time.Duration;
import java.util.stream.Stream;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.provider.Arguments;

import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -35,6 +38,10 @@ class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests

private WebFluxStatelessServerTransport mcpStreamableServerTransport;

static Stream<Arguments> clientsForTesting() {
return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux"));
}

@Override
protected void prepareClients(int port, String mcpEndpoint) {
clientBuilders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import java.time.Duration;
import java.util.Map;
import java.util.stream.Stream;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.provider.Arguments;

import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down Expand Up @@ -43,6 +46,10 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
.create(Map.of("important", "value"));

static Stream<Arguments> clientsForTesting() {
return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux"));
}

@Override
protected void prepareClients(int port, String mcpEndpoint) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.client.WebClient;

import static io.modelcontextprotocol.utils.McpJsonMapperUtils.JSON_MAPPER;
import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down
Loading