diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index a8a59a63a..033139adc 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -6,7 +6,10 @@ import java.time.Duration; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; @@ -44,10 +47,6 @@ */ public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; abstract protected ClientMcpTransport createMcpTransport(); @@ -66,25 +65,47 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) + McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); } @AfterEach void tearDown() { - if (mcpAsyncClient != null) { - StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); - } onClose(); } @@ -93,258 +114,323 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - }).verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listTools(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools")) + .verify(); + }); } @Test void testListTools() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.ping()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before pinging the server")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.ping()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before pinging the " + "server")) + .verify(); + }); } @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before calling tools")) - .verify(); + StepVerifier.withVirtualTime(() -> mcpAsyncClient.callTool(callToolRequest)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before calling tools")) + .verify(); + }); } @Test void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull(); - assertThat(callToolResult.content()).isNotNull(); - assertThat(callToolResult.isError()).isNull(); - }) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .expectError(Exception.class) - .verify(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); } @Test void testListResourcesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResources(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resources")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResources(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing resources")) + .verify(); + }); } @Test void testListResources() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing prompts")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listPrompts(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing prompts")) + .verify(); + }); } @Test void testListPrompts() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + withClient(createMcpTransport(), mcpAsyncClient -> { + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(request)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before getting prompts")) - .verify(); + StepVerifier.withVirtualTime(() -> mcpAsyncClient.getPrompt(request)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before getting prompts")) + .verify(); + }); } @Test void testGetPrompt() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) - .expectErrorMatches(error -> error instanceof McpError && error.getMessage() - .equals("Client must be initialized before sending roots list changed notification")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.rootsListChangedNotification()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before sending roots list changed notification")) + .verify(); + }); } @Test void testRootsListChanged() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resource templates")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResourceTemplates()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing resource templates")) + .verify(); + }); } @Test void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -353,36 +439,44 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + + var transport = createMcpTransport(); + var client = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer( + prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) + .build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() .experimental(Map.of("feature", "test")) .roots(true) @@ -391,18 +485,14 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - StepVerifier.create(client.initialize()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> - StepVerifier.create(client.closeGracefully()).verifyComplete(); + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); } // --------------------------------------- @@ -411,43 +501,52 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before setting logging level")) - .verify(); + withClient(createMcpTransport(), + mcpAsyncClient -> StepVerifier + .withVirtualTime(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before setting logging level")) + .verify()); } @Test void testLoggingLevels() { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier.create(testAllLevels).verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 0f83e31e3..032f8684a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -7,6 +7,9 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -27,6 +30,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -40,12 +47,8 @@ */ public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { @@ -62,254 +65,322 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) + McpClient.SyncSpec builder = McpClient.sync(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + } @AfterEach void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } onClose(); } + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function operation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) + // offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallTimesOut(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -318,18 +389,17 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -338,40 +408,37 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 614c65125..d35db3f89 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -353,14 +353,15 @@ public Mono closeGracefully() { // Give a short time for any pending messages to be processed return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.fromFuture(() -> { + })).then(Mono.defer(() -> { logger.debug("Sending TERM to process"); if (this.process != null) { this.process.destroy(); - return process.onExit(); + return Mono.fromFuture(process.onExit()); } else { - return CompletableFuture.failedFuture(new RuntimeException("Process not started")); + logger.warn("Process not started"); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index e2d354f4a..46aefafcb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -270,8 +270,10 @@ public Mono sendNotification(String method, Map params) { */ @Override public Mono closeGracefully() { - this.connection.dispose(); - return transport.closeGracefully(); + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); } /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 39bc49953..720388545 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -6,8 +6,12 @@ import java.time.Duration; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -45,10 +49,6 @@ // KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; abstract protected ClientMcpTransport createMcpTransport(); @@ -67,285 +67,326 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) + McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); } @AfterEach void tearDown() { - if (mcpAsyncClient != null) { - StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); - } onClose(); } + void verifyInitializationTimeout(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - }).verify(); + verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.ping()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before pinging the server")) - .verify(); + verifyInitializationTimeout(client -> client.ping(), "pinging the server"); } @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before calling tools")) - .verify(); + verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull(); - assertThat(callToolResult.content()).isNotNull(); - assertThat(callToolResult.isError()).isNull(); - }) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .expectError(Exception.class) - .verify(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); } @Test void testListResourcesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResources(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resources")) - .verify(); + verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing prompts")) - .verify(); + verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); } @Test void testListPrompts() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.getPrompt(request)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before getting prompts")) - .verify(); + verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); } @Test void testGetPrompt() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) - .expectErrorMatches(error -> error instanceof McpError && error.getMessage() - .equals("Client must be initialized before sending roots list changed notification")) - .verify(); + verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resource templates")) - .verify(); + verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); } @Test void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -354,36 +395,44 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + + var transport = createMcpTransport(); + var client = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer( + prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) + .build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() .experimental(Map.of("feature", "test")) .roots(true) @@ -392,18 +441,14 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - StepVerifier.create(client.initialize()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> - StepVerifier.create(client.closeGracefully()).verifyComplete(); + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); } // --------------------------------------- @@ -412,43 +457,46 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before setting logging level")) - .verify(); + verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier.create(testAllLevels).verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 52a0138f6..1c042bf24 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -7,6 +7,9 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -27,6 +30,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -41,12 +48,8 @@ // KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { @@ -63,254 +66,322 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) + McpClient.SyncSpec builder = McpClient.sync(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + } @AfterEach void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } onClose(); } + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function operation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) + // offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallTimesOut(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -319,18 +390,17 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -339,40 +409,37 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 6d759b4ba..ebf10b9a3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -13,6 +15,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; @@ -35,15 +38,26 @@ protected ClientMcpTransport createMcpTransport() { } @Test - void customErrorHandlerShouldReceiveErrors() { + void customErrorHandlerShouldReceiveErrors() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); AtomicReference receivedError = new AtomicReference<>(); - ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); + ClientMcpTransport transport = createMcpTransport(); + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + ((StdioClientTransport) transport).setStdErrorHandler(error -> { + receivedError.set(error); + latch.countDown(); + }); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); + + StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5)); } protected Duration getInitializationTimeout() {