diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index cf8142c68..ed98962f1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -100,6 +100,9 @@ public class McpAsyncClient { public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { }; + public static final TypeReference PROGRESS_NOTIFICATION_TYPE_REF = new TypeReference<>() { + }; + /** * Client capabilities. */ @@ -253,6 +256,16 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); + // Utility Progress Notification + List>> progressConsumersFinal = new ArrayList<>(); + progressConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification))); + if (!Utils.isEmpty(features.progressConsumers())) { + progressConsumersFinal.addAll(features.progressConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, + asyncProgressNotificationHandler(progressConsumersFinal)); + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers, @@ -828,6 +841,19 @@ private NotificationHandler asyncLoggingNotificationHandler( }; } + private NotificationHandler asyncProgressNotificationHandler( + List>> progressConsumers) { + + return params -> { + McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params, + PROGRESS_NOTIFICATION_TYPE_REF); + + return Flux.fromIterable(progressConsumers) + .flatMap(consumer -> consumer.apply(progressNotification)) + .then(); + }; + } + /** * Sets the minimum logging level for messages received from the server. The client * will only receive log messages at or above the specified severity level. diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index d8925b005..ce603a0fa 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -177,6 +177,8 @@ class SyncSpec { private final List> loggingConsumers = new ArrayList<>(); + private final List> progressConsumers = new ArrayList<>(); + private Function samplingHandler; private Function elicitationHandler; @@ -377,6 +379,36 @@ public SyncSpec loggingConsumers(List progressConsumer) { + Assert.notNull(progressConsumer, "Progress consumer must not be null"); + this.progressConsumers.add(progressConsumer); + return this; + } + + /** + * Adds a multiple consumers to be notified of progress notifications from the + * server. This allows the client to track long-running operations and provide + * feedback to users. + * @param progressConsumers A list of consumers that receives progress + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public SyncSpec progressConsumers(List> progressConsumers) { + Assert.notNull(progressConsumers, "Progress consumers must not be null"); + this.progressConsumers.addAll(progressConsumers); + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -385,7 +417,8 @@ public SyncSpec loggingConsumers(List>> loggingConsumers = new ArrayList<>(); + private final List>> progressConsumers = new ArrayList<>(); + private Function> samplingHandler; private Function> elicitationHandler; @@ -663,8 +698,8 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, - this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler, - this.elicitationHandler)); + this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, + this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index bd1a0985a..ed3f48b9b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -59,6 +59,7 @@ class McpClientFeatures { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -68,6 +69,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, + List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler) { @@ -79,6 +81,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progressconsumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -89,6 +92,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, + List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler) { @@ -106,6 +110,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; } @@ -149,6 +154,12 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } + List>> progressConsumers = new ArrayList<>(); + for (Consumer consumer : syncSpec.progressConsumers()) { + progressConsumers.add(p -> Mono.fromRunnable(() -> consumer.accept(p)) + .subscribeOn(Schedulers.boundedElastic())); + } + Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); @@ -159,7 +170,7 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, - loggingConsumers, samplingHandler, elicitationHandler); + loggingConsumers, progressConsumers, samplingHandler, elicitationHandler); } } @@ -174,6 +185,7 @@ public static Async fromSync(Sync syncSpec) { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -183,6 +195,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, + List> progressConsumers, Function samplingHandler, Function elicitationHandler) { @@ -196,6 +209,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param resourcesUpdateConsumers the resource update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -205,6 +219,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, + List> progressConsumers, Function samplingHandler, Function elicitationHandler) { @@ -222,6 +237,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index e56c695fa..c6bb23586 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -177,6 +177,20 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN }); } + /** + * Sends a notification to the client that the current progress status has changed for + * long-running operations. + * @param progressNotification The progress notification to send + * @return A Mono that completes when the notification has been sent + */ + public Mono progressNotification(McpSchema.ProgressNotification progressNotification) { + if (progressNotification == null) { + return Mono.error(new McpError("Progress notification must not be null")); + } + + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification); + } + /** * Sends a ping request to the client. * @return A Mono that completes with clients's ping response diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 3ce599c8b..12edfb341 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -209,7 +209,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * represents a specific capability. * * @param tool The tool definition including name, description, and parameter schema - * @param call Deprecated. Uset he {@link AsyncToolSpecification#callHandler} instead. + * @param call Deprecated. Use the {@link AsyncToolSpecification#callHandler} instead. * @param callHandler The function that implements the tool's logic, receiving a * {@link McpAsyncServerExchange} and a * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index acdc18e10..2a84802ba 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -23,6 +23,7 @@ import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.util.annotation.Nullable; /** * Based on the JSON-RPC 2.0 @@ -58,6 +59,8 @@ private McpSchema() { public static final String METHOD_PING = "ping"; + public static final String METHOD_NOTIFICATION_PROGRESS = "notifications/progress"; + // Tool Methods public static final String METHOD_TOOLS_LIST = "tools/list"; @@ -249,7 +252,7 @@ public record InitializeRequest( // @formatter:off @JsonProperty("capabilities") ClientCapabilities capabilities, @JsonProperty("clientInfo") Implementation clientInfo, @JsonProperty("_meta") Map meta) implements Request { - + public InitializeRequest(String protocolVersion, ClientCapabilities capabilities, Implementation clientInfo) { this(protocolVersion, capabilities, clientInfo, null); } @@ -462,7 +465,7 @@ public ServerCapabilities build() { public record Implementation(// @formatter:off @JsonProperty("name") String name, @JsonProperty("title") String title, - @JsonProperty("version") String version) implements BaseMetadata {// @formatter:on + @JsonProperty("version") String version) implements BaseMetadata {// @formatter:on public Implementation(String name, String version) { this(name, null, version); @@ -1053,6 +1056,8 @@ private static JsonSchema parseSchema(String schema) { * tools/list. * @param arguments Arguments to pass to the tool. These must conform to the tool's * input schema. + * @param meta Optional metadata about the request. This can include additional + * information like `progressToken` */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) @@ -1063,9 +1068,10 @@ public record CallToolRequest(// @formatter:off public CallToolRequest(String name, String jsonArguments) { this(name, parseJsonArguments(jsonArguments), null); - } - public CallToolRequest(String name, Map arguments) { - this(name, arguments, null); + } + + public CallToolRequest(String name, Map arguments) { + this(name, arguments, null); } private static Map parseJsonArguments(String jsonArguments) { @@ -1317,7 +1323,7 @@ public record CreateMessageRequest(// @formatter:off @JsonProperty("metadata") Map metadata, @JsonProperty("_meta") Map meta) implements Request { - + // backwards compatibility constructor public CreateMessageRequest(List messages, ModelPreferences modelPreferences, String systemPrompt, ContextInclusionStrategy includeContext, @@ -1771,7 +1777,7 @@ public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argumen public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument) { this(ref, argument, null, null); } - + public record CompleteArgument( @JsonProperty("name") String name, @JsonProperty("value") String value) { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index ac10df4f5..514903e39 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -1070,4 +1070,75 @@ void testPingSuccess() { mcpServer.close(); } + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @Test + void testProgressNotification() { + // Create a list to store received logging notifications + List receivedNotifications = new ArrayList<>(); + + // Create server with a tool that sends logging notifications + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("progress-test", "Test progress notifications", emptyJsonSchema), null, + (exchange, request) -> { + var progressToken = (String) request.meta().get("progressToken"); + + exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.1, 1.0, "Test progress 1/10")) + .block(); + + exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Test progress 5/10")) + .block(); + + exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Test progress 10/10")) + .block(); + + return Mono.just(new CallToolResult("Progress test completed", false)); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(receivedNotifications::add).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + CallToolResult result = mcpClient.callTool( + new McpSchema.CallToolRequest("progress-test", Map.of(), Map.of("progressToken", "test-token"))); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + // Wait for notifications to be processed + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + + System.out.println("Received notifications: " + receivedNotifications); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(3); + + // Check the progress notifications + assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progressToken)) + .containsExactlyInAnyOrder("test-token", "test-token", "test-token"); + assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progress)) + .containsExactlyInAnyOrder(0.1, 0.5, 1.0); + }); + } + mcpServer.close(); + } + }