From 6aa993df6129d310628d6f0956e525ecfea25514 Mon Sep 17 00:00:00 2001 From: Tyler Kendrick <145080887+Tyler-R-Kendrick@users.noreply.github.com> Date: Tue, 1 Apr 2025 09:40:11 -0400 Subject: [PATCH 1/2] Extend progress notification support --- src/ModelContextProtocol/Client/McpClient.cs | 2 +- .../Client/McpClientExtensions.cs | 25 +++++- .../ClientTokenProgress.cs | 28 +++++++ .../Protocol/Types/Capabilities.cs | 2 +- .../Types/ListPromptsRequestParams.cs | 10 +-- .../ListResourceTemplatesRequestParams.cs | 10 +-- .../Types/ListResourcesRequestParams.cs | 10 +-- .../Protocol/Types/ListRootsRequestParams.cs | 9 +-- .../Protocol/Types/ListToolsRequestParams.cs | 10 +-- .../Protocol/Types/PaginatedRequest.cs | 15 ++++ .../Server/AIFunctionMcpServerTool.cs | 2 +- src/ModelContextProtocol/Server/McpServer.cs | 1 - .../Server/McpServerExtensions.cs | 20 +++++ .../ServerTokenProgress.cs | 27 +++++++ .../ClientIntegrationTests.cs | 7 +- .../Server/McpServerTests.cs | 81 +++++++++++++++++++ .../SseIntegrationTests.cs | 2 +- .../SseServerIntegrationTests.cs | 2 +- 18 files changed, 207 insertions(+), 56 deletions(-) create mode 100644 src/ModelContextProtocol/ClientTokenProgress.cs create mode 100644 src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs create mode 100644 src/ModelContextProtocol/ServerTokenProgress.cs diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index b326f3c58..3640d4f57 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -42,7 +42,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp SetRequestHandler( RequestMethods.SamplingCreateMessage, - (request, ct) => samplingHandler(request, ct)); + (request, ct) => samplingHandler(request, new ClientTokenProgress(this, request?.Meta?.ProgressToken), ct)); } if (options.Capabilities?.Roots is { } rootsCapability) diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 36742b14a..448ed5fd2 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -531,17 +531,34 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat /// /// The with which to satisfy sampling requests. /// The created handler delegate. - public static Func> CreateSamplingHandler(this IChatClient chatClient) + public static Func, CancellationToken, Task> CreateSamplingHandler( + this IChatClient chatClient) { Throw.IfNull(chatClient); - return async (requestParams, cancellationToken) => + return async (requestParams, progress, cancellationToken) => { Throw.IfNull(requestParams); var (messages, options) = requestParams.ToChatClientArguments(); - var response = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - return response.ToCreateMessageResult(); + var progressToken = requestParams.Meta?.ProgressToken; + int progressValue = 0; + var streamingResponses = chatClient.GetStreamingResponseAsync( + messages, options, cancellationToken); + List updates = []; + await foreach (var streamingResponse in streamingResponses) + { + updates.Add(streamingResponse); + if (progressToken is not null) + { + progress.Report(new() + { + Progress = ++progressValue, + }); + } + } + + return updates.ToChatResponse().ToCreateMessageResult(); }; } diff --git a/src/ModelContextProtocol/ClientTokenProgress.cs b/src/ModelContextProtocol/ClientTokenProgress.cs new file mode 100644 index 000000000..86201ac32 --- /dev/null +++ b/src/ModelContextProtocol/ClientTokenProgress.cs @@ -0,0 +1,28 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; + +namespace ModelContextProtocol; + +internal sealed class ClientTokenProgress(IMcpClient client, ProgressToken? progressToken) + : IProgress +{ + /// + public void Report(ProgressNotificationValue value) + { + if (progressToken is null) return; + _ = client.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ProgressNotification, + Params = new ProgressNotification() + { + ProgressToken = progressToken.Value, + Progress = new() + { + Progress = value.Progress, + Total = value.Total, + Message = value.Message, + }, + }, + }, CancellationToken.None); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index f33357783..c0cf41977 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -55,7 +55,7 @@ public class SamplingCapability /// Gets or sets the handler for sampling requests. [JsonIgnore] - public Func>? SamplingHandler { get; set; } + public Func, CancellationToken, Task>? SamplingHandler { get; set; } } /// diff --git a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs index 419b6fceb..a5500d410 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of prompts and prompt templates the server has. /// See the schema for details /// -public class ListPromptsRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListPromptsRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs index f4060dbd0..8a54f6e8e 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of resource templates the server has. /// See the schema for details /// -public class ListResourceTemplatesRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} \ No newline at end of file +public class ListResourceTemplatesRequestParams : PaginatedRequestParams; \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs index ad7f19b31..30bea5b87 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of resources the server has. /// See the schema for details /// -public class ListResourcesRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListResourcesRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs index dae1b75c1..a5eec7a15 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs @@ -6,11 +6,4 @@ namespace ModelContextProtocol.Protocol.Types; /// A request from the server to get a list of root URIs from the client. /// See the schema for details /// -public class ListRootsRequestParams -{ - /// - /// Optional progress token for out-of-band progress notifications. - /// - [System.Text.Json.Serialization.JsonPropertyName("progressToken")] - public ProgressToken? ProgressToken { get; init; } -} +public class ListRootsRequestParams : RequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs index 4f18fbb73..64ac18599 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of tools the server has. /// See the schema for details /// -public class ListToolsRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListToolsRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs new file mode 100644 index 000000000..abf47dd3c --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs @@ -0,0 +1,15 @@ +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Used as a base class for paginated requests. +/// See the schema for details +/// +public class PaginatedRequestParams : RequestParams +{ + /// + /// An opaque token representing the current pagination position. + /// If provided, the server should return results starting after this cursor. + /// + [System.Text.Json.Serialization.JsonPropertyName("cursor")] + public string? Cursor { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 03d01ee42..ee3b82728 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -111,7 +111,7 @@ private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions( if (requestContent?.Server is { } server && requestContent?.Params?.Meta?.ProgressToken is { } progressToken) { - return new TokenProgress(server, progressToken); + return new ServerTokenProgress(server, progressToken); } return NullProgress.Instance; diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 8aa32965e..aa35ac339 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -68,7 +68,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? }); SetToolsHandler(options); - SetInitializeHandler(options); SetCompletionHandler(options); SetPingHandler(); diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 3b541ec80..84a5abab5 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -169,6 +169,26 @@ public static Task RequestRootsAsync( cancellationToken); } + /// + /// Requests the client to list the roots it exposes. + /// + /// The server issueing the request. + /// The notification to send. + /// A token to cancel the operation. + /// A task containing the response from the client. + public static Task NotifyProgressAsync( + this IMcpServer server, + ProgressNotification notification, + CancellationToken cancellationToken = default) + { + Throw.IfNull(server); + return server.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ProgressNotification, + Params = notification, + }, cancellationToken); + } + /// Provides an implementation that's implemented via client sampling. /// private sealed class SamplingChatClient(IMcpServer server) : IChatClient diff --git a/src/ModelContextProtocol/ServerTokenProgress.cs b/src/ModelContextProtocol/ServerTokenProgress.cs new file mode 100644 index 000000000..4780681dc --- /dev/null +++ b/src/ModelContextProtocol/ServerTokenProgress.cs @@ -0,0 +1,27 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol; + +/// +/// Provides an tied to a specific progress token and that will issue +/// progress notifications to the supplied endpoint. +/// +internal sealed class ServerTokenProgress(IMcpServer server, ProgressToken progressToken) + : IProgress +{ + /// + public void Report(ProgressNotificationValue value) + { + _ = server.NotifyProgressAsync(new() + { + ProgressToken = progressToken, + Progress = new() + { + Progress = value.Progress, + Total = value.Total, + Message = value.Message, + }, + }, CancellationToken.None); + } +} diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 9b598d332..a08bbb51d 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -355,7 +355,7 @@ public async Task Sampling_Stdio(string clientId) { Sampling = new() { - SamplingHandler = (_, _) => + SamplingHandler = (_, _, _) => { samplingHandlerCalls++; return Task.FromResult(new CreateMessageResult @@ -511,6 +511,9 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() [Fact(Skip = "Requires OpenAI API Key", SkipWhen = nameof(NoOpenAIKeySet))] public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { + var samplingHandler = new OpenAIClient(s_openAIKey) + .AsChatClient("gpt-4o-mini") + .CreateSamplingHandler(); await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" }, @@ -518,7 +521,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { Sampling = new() { - SamplingHandler = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini").CreateSamplingHandler(), + SamplingHandler = samplingHandler, }, }, }, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index a3290049c..43495f356 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -635,4 +635,85 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); } + + [Fact] + public async Task NotifyProgress_Should_Be_Handled() + { + var taskCompletionSource = new TaskCompletionSource(); + bool notificationHandled = false; + await Notifications_Are_Handled( + serverCapabilities: null, + method: NotificationMethods.ProgressNotification, + parameters: new ProgressNotification() + { + ProgressToken = new(), + Progress = new() + { + Progress = 50, + Total = 100, + Message = "Progress message", + }, + }, + configureOptions: null, + configureServer: server => + { + server.AddNotificationHandler(NotificationMethods.ProgressNotification, + (notification) => + { + notificationHandled = true; + var progress = (ProgressNotificationValue?)notification.Params; + Assert.NotNull(progress); + var progressValue = progress.Value; + taskCompletionSource.SetResult(); + Assert.Equal(50, progressValue.Progress); + Assert.Equal(100, progressValue.Total); + Assert.Equal("Progress message", progressValue.Message); + return Task.CompletedTask; + }); + }, + assertResult: async response => + { + //Note: awaiting here so handlers are guaranteed to be called first. + await taskCompletionSource.Task.WaitAsync(TimeSpan.FromSeconds(1)); + Assert.True(notificationHandled); + }); + } + + private async Task Notifications_Are_Handled( + ServerCapabilities? serverCapabilities, + string method, object? parameters, + Action? configureOptions, + Action? configureServer, + Action assertResult) + { + await using TestServerTransport transport = new(); + var options = CreateOptions(serverCapabilities); + configureOptions?.Invoke(options); + + await using var server = McpServerFactory.Create( + transport, options, LoggerFactory, _serviceProvider); + + configureServer?.Invoke(server); + await server.RunAsync(); + + TaskCompletionSource receivedMessage = new(); + + transport.OnMessageSent = (message) => + { + Assert.NotNull(message); + if (message is JsonRpcNotification notification && notification.Method == method) + { + assertResult(notification); + receivedMessage.SetResult(notification); + } + }; + + await transport.SendMessageAsync(new JsonRpcNotification + { + Method = method, + Params = parameters, + }); + + var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(1)); + } } diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 1874953d9..2d9ba133e 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -116,7 +116,7 @@ public async Task Sampling_Sse_EverythingServer() { Sampling = new() { - SamplingHandler = (_, _) => + SamplingHandler = (_, _, _) => { samplingHandlerCalls++; return Task.FromResult(new CreateMessageResult diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 44befcd10..4fad0d5bc 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -215,7 +215,7 @@ public async Task Sampling_Sse_TestServer() var options = SseServerIntegrationTestFixture.CreateDefaultClientOptions(); options.Capabilities ??= new(); options.Capabilities.Sampling ??= new(); - options.Capabilities.Sampling.SamplingHandler = async (_, _) => + options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => { samplingHandlerCalls++; return new CreateMessageResult From e9faef765a3cd80b109ce4ce7a8d2ba0a324e585 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 1 Apr 2025 11:24:49 -0400 Subject: [PATCH 2/2] Address feedback and fix test --- src/ModelContextProtocol/Client/IMcpClient.cs | 41 +------- src/ModelContextProtocol/Client/McpClient.cs | 7 +- .../Client/McpClientExtensions.cs | 15 ++- .../ClientTokenProgress.cs | 28 ------ src/ModelContextProtocol/IMcpEndpoint.cs | 35 +++++++ .../McpEndpointExtensions.cs | 34 +++++++ .../Protocol/Types/Tool.cs | 2 +- .../Server/AIFunctionMcpServerTool.cs | 2 +- src/ModelContextProtocol/Server/IMcpServer.cs | 39 +------- .../Server/McpServerExtensions.cs | 22 +---- .../ServerTokenProgress.cs | 27 ------ src/ModelContextProtocol/TokenProgress.cs | 20 +--- .../Server/McpServerTests.cs | 94 ++++++------------- 13 files changed, 118 insertions(+), 248 deletions(-) delete mode 100644 src/ModelContextProtocol/ClientTokenProgress.cs create mode 100644 src/ModelContextProtocol/IMcpEndpoint.cs create mode 100644 src/ModelContextProtocol/McpEndpointExtensions.cs delete mode 100644 src/ModelContextProtocol/ServerTokenProgress.cs diff --git a/src/ModelContextProtocol/Client/IMcpClient.cs b/src/ModelContextProtocol/Client/IMcpClient.cs index 6761c4ef9..357ce3843 100644 --- a/src/ModelContextProtocol/Client/IMcpClient.cs +++ b/src/ModelContextProtocol/Client/IMcpClient.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Client; /// /// Represents an instance of an MCP client connecting to a specific server. /// -public interface IMcpClient : IAsyncDisposable +public interface IMcpClient : IMcpEndpoint { /// /// Gets the capabilities supported by the server. @@ -24,40 +23,4 @@ public interface IMcpClient : IAsyncDisposable /// It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt. /// string? ServerInstructions { get; } - - /// - /// Adds a handler for server notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); - - /// - /// Sends a generic JSON-RPC request to the server. - /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the server's response. - /// - /// It is recommended to use the capability-specific methods that use this one in their implementation. - /// Use this method for custom requests or those not yet covered explicitly. - /// - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; - - /// - /// Sends a message to the server. - /// - /// The message. - /// A token to cancel the operation. - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); } \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 3640d4f57..773301c01 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -42,7 +42,10 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp SetRequestHandler( RequestMethods.SamplingCreateMessage, - (request, ct) => samplingHandler(request, new ClientTokenProgress(this, request?.Meta?.ProgressToken), ct)); + (request, cancellationToken) => samplingHandler( + request, + request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken)); } if (options.Capabilities?.Roots is { } rootsCapability) @@ -54,7 +57,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp SetRequestHandler( RequestMethods.RootsList, - (request, ct) => rootsHandler(request, ct)); + (request, cancellationToken) => rootsHandler(request, cancellationToken)); } } diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 448ed5fd2..4d1411742 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -8,9 +8,7 @@ namespace ModelContextProtocol.Client; -/// -/// Provides extensions for operating on MCP clients. -/// +/// Provides extension methods for interacting with an . public static class McpClientExtensions { /// @@ -542,18 +540,17 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat var (messages, options) = requestParams.ToChatClientArguments(); var progressToken = requestParams.Meta?.ProgressToken; - int progressValue = 0; - var streamingResponses = chatClient.GetStreamingResponseAsync( - messages, options, cancellationToken); + List updates = []; - await foreach (var streamingResponse in streamingResponses) + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { - updates.Add(streamingResponse); + updates.Add(update); + if (progressToken is not null) { progress.Report(new() { - Progress = ++progressValue, + Progress = updates.Count, }); } } diff --git a/src/ModelContextProtocol/ClientTokenProgress.cs b/src/ModelContextProtocol/ClientTokenProgress.cs deleted file mode 100644 index 86201ac32..000000000 --- a/src/ModelContextProtocol/ClientTokenProgress.cs +++ /dev/null @@ -1,28 +0,0 @@ -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol.Messages; - -namespace ModelContextProtocol; - -internal sealed class ClientTokenProgress(IMcpClient client, ProgressToken? progressToken) - : IProgress -{ - /// - public void Report(ProgressNotificationValue value) - { - if (progressToken is null) return; - _ = client.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ProgressNotification, - Params = new ProgressNotification() - { - ProgressToken = progressToken.Value, - Progress = new() - { - Progress = value.Progress, - Total = value.Total, - Message = value.Message, - }, - }, - }, CancellationToken.None); - } -} \ No newline at end of file diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs new file mode 100644 index 000000000..d7dbd211e --- /dev/null +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -0,0 +1,35 @@ +using ModelContextProtocol.Protocol.Messages; + +namespace ModelContextProtocol; + +/// Represents a client or server MCP endpoint. +public interface IMcpEndpoint : IAsyncDisposable +{ + /// Sends a generic JSON-RPC request to the connected endpoint. + /// The expected response type. + /// The JSON-RPC request to send. + /// A token to cancel the operation. + /// A task containing the client's response. + Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; + + /// Sends a message to the connected endpoint. + /// The message. + /// A token to cancel the operation. + Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + + /// + /// Adds a handler for server notifications of a specific method. + /// + /// The notification method to handle. + /// The async handler function to process notifications. + /// + /// + /// Each method may have multiple handlers. Adding a handler for a method that already has one + /// will not replace the existing handler. + /// + /// + /// provides constants for common notification methods. + /// + /// + void AddNotificationHandler(string method, Func handler); +} diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs new file mode 100644 index 000000000..c27ada2b2 --- /dev/null +++ b/src/ModelContextProtocol/McpEndpointExtensions.cs @@ -0,0 +1,34 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol; + +/// Provides extension methods for interacting with an . +public static class McpEndpointExtensions +{ + /// Notifies the connected endpoint of progress. + /// The endpoint issueing the notification. + /// The identifying the operation. + /// The progress update to send. + /// A token to cancel the operation. + /// A task representing the completion of the operation. + /// is . + public static Task NotifyProgressAsync( + this IMcpEndpoint endpoint, + ProgressToken progressToken, + ProgressNotificationValue progress, + CancellationToken cancellationToken = default) + { + Throw.IfNull(endpoint); + + return endpoint.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ProgressNotification, + Params = new ProgressNotification() + { + ProgressToken = progressToken, + Progress = progress, + }, + }, cancellationToken); + } +} diff --git a/src/ModelContextProtocol/Protocol/Types/Tool.cs b/src/ModelContextProtocol/Protocol/Types/Tool.cs index ed0c71290..dc0b774c0 100644 --- a/src/ModelContextProtocol/Protocol/Types/Tool.cs +++ b/src/ModelContextProtocol/Protocol/Types/Tool.cs @@ -38,7 +38,7 @@ public JsonElement InputSchema { if (!McpJsonUtilities.IsValidMcpToolSchema(value)) { - throw new ArgumentException("The specified document is not a valid MPC tool JSON schema.", nameof(InputSchema)); + throw new ArgumentException("The specified document is not a valid MCP tool JSON schema.", nameof(InputSchema)); } _inputSchema = value; diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index ee3b82728..03d01ee42 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -111,7 +111,7 @@ private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions( if (requestContent?.Server is { } server && requestContent?.Params?.Meta?.ProgressToken is { } progressToken) { - return new ServerTokenProgress(server, progressToken); + return new TokenProgress(server, progressToken); } return NullProgress.Instance; diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index e8dffaf19..19b3967ad 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Server; /// /// Represents a server that can communicate with a client using the MCP protocol. /// -public interface IMcpServer : IAsyncDisposable +public interface IMcpServer : IMcpEndpoint { /// /// Gets the capabilities supported by the client. @@ -26,42 +25,8 @@ public interface IMcpServer : IAsyncDisposable /// IServiceProvider? Services { get; } - /// - /// Adds a handler for client notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); - /// /// Runs the server, listening for and handling client requests. /// Task RunAsync(CancellationToken cancellationToken = default); - - /// - /// Sends a generic JSON-RPC request to the client. - /// NB! This is a temporary method that is available to send not yet implemented feature messages. - /// Once all MCP features are implemented this will be made private, as it is purely a convenience for those who wish to implement features ahead of the library. - /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the client's response. - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; - - /// - /// Sends a message to the client. - /// - /// The message. - /// A token to cancel the operation. - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 84a5abab5..ddaf45f28 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Server; -/// +/// Provides extension methods for interacting with an . public static class McpServerExtensions { /// @@ -169,26 +169,6 @@ public static Task RequestRootsAsync( cancellationToken); } - /// - /// Requests the client to list the roots it exposes. - /// - /// The server issueing the request. - /// The notification to send. - /// A token to cancel the operation. - /// A task containing the response from the client. - public static Task NotifyProgressAsync( - this IMcpServer server, - ProgressNotification notification, - CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - return server.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ProgressNotification, - Params = notification, - }, cancellationToken); - } - /// Provides an implementation that's implemented via client sampling. /// private sealed class SamplingChatClient(IMcpServer server) : IChatClient diff --git a/src/ModelContextProtocol/ServerTokenProgress.cs b/src/ModelContextProtocol/ServerTokenProgress.cs deleted file mode 100644 index 4780681dc..000000000 --- a/src/ModelContextProtocol/ServerTokenProgress.cs +++ /dev/null @@ -1,27 +0,0 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; - -namespace ModelContextProtocol; - -/// -/// Provides an tied to a specific progress token and that will issue -/// progress notifications to the supplied endpoint. -/// -internal sealed class ServerTokenProgress(IMcpServer server, ProgressToken progressToken) - : IProgress -{ - /// - public void Report(ProgressNotificationValue value) - { - _ = server.NotifyProgressAsync(new() - { - ProgressToken = progressToken, - Progress = new() - { - Progress = value.Progress, - Total = value.Total, - Message = value.Message, - }, - }, CancellationToken.None); - } -} diff --git a/src/ModelContextProtocol/TokenProgress.cs b/src/ModelContextProtocol/TokenProgress.cs index 46af03f4f..62834e75a 100644 --- a/src/ModelContextProtocol/TokenProgress.cs +++ b/src/ModelContextProtocol/TokenProgress.cs @@ -1,30 +1,16 @@ using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications to the supplied endpoint. +/// progress notifications on the supplied endpoint. /// -internal sealed class TokenProgress(IMcpServer server, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = server.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ProgressNotification, - Params = new ProgressNotification() - { - ProgressToken = progressToken, - Progress = new() - { - Progress = value.Progress, - Total = value.Total, - Message = value.Message, - }, - }, - }, CancellationToken.None); + _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 43495f356..bded8e990 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -639,14 +639,26 @@ public Task RunAsync(CancellationToken cancellationToken = default) => [Fact] public async Task NotifyProgress_Should_Be_Handled() { - var taskCompletionSource = new TaskCompletionSource(); - bool notificationHandled = false; - await Notifications_Are_Handled( - serverCapabilities: null, - method: NotificationMethods.ProgressNotification, - parameters: new ProgressNotification() + await using TestServerTransport transport = new(); + var options = CreateOptions(); + + var notificationReceived = new TaskCompletionSource(); + + var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); + server.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => + { + notificationReceived.SetResult(notification); + return Task.CompletedTask; + }); + + Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); + + await transport.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.ProgressNotification, + Params = new ProgressNotification() { - ProgressToken = new(), + ProgressToken = new("abc"), Progress = new() { Progress = 50, @@ -654,66 +666,16 @@ await Notifications_Are_Handled( Message = "Progress message", }, }, - configureOptions: null, - configureServer: server => - { - server.AddNotificationHandler(NotificationMethods.ProgressNotification, - (notification) => - { - notificationHandled = true; - var progress = (ProgressNotificationValue?)notification.Params; - Assert.NotNull(progress); - var progressValue = progress.Value; - taskCompletionSource.SetResult(); - Assert.Equal(50, progressValue.Progress); - Assert.Equal(100, progressValue.Total); - Assert.Equal("Progress message", progressValue.Message); - return Task.CompletedTask; - }); - }, - assertResult: async response => - { - //Note: awaiting here so handlers are guaranteed to be called first. - await taskCompletionSource.Task.WaitAsync(TimeSpan.FromSeconds(1)); - Assert.True(notificationHandled); - }); - } - - private async Task Notifications_Are_Handled( - ServerCapabilities? serverCapabilities, - string method, object? parameters, - Action? configureOptions, - Action? configureServer, - Action assertResult) - { - await using TestServerTransport transport = new(); - var options = CreateOptions(serverCapabilities); - configureOptions?.Invoke(options); - - await using var server = McpServerFactory.Create( - transport, options, LoggerFactory, _serviceProvider); - - configureServer?.Invoke(server); - await server.RunAsync(); - - TaskCompletionSource receivedMessage = new(); - - transport.OnMessageSent = (message) => - { - Assert.NotNull(message); - if (message is JsonRpcNotification notification && notification.Method == method) - { - assertResult(notification); - receivedMessage.SetResult(notification); - } - }; + }, TestContext.Current.CancellationToken); - await transport.SendMessageAsync(new JsonRpcNotification - { - Method = method, - Params = parameters, - }); + var notification = await notificationReceived.Task; + var progress = (ProgressNotification)notification.Params!; + Assert.Equal("\"abc\"", progress.ProgressToken.ToString()); + Assert.Equal(50, progress.Progress.Progress); + Assert.Equal(100, progress.Progress.Total); + Assert.Equal("Progress message", progress.Progress.Message); - var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(1)); + await server.DisposeAsync(); + await serverTask; } }