diff --git a/src/ModelContextProtocol/Client/IMcpClient.cs b/src/ModelContextProtocol/Client/IMcpClient.cs index 6761c4ef9..357ce3843 100644 --- a/src/ModelContextProtocol/Client/IMcpClient.cs +++ b/src/ModelContextProtocol/Client/IMcpClient.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Client; /// /// Represents an instance of an MCP client connecting to a specific server. /// -public interface IMcpClient : IAsyncDisposable +public interface IMcpClient : IMcpEndpoint { /// /// Gets the capabilities supported by the server. @@ -24,40 +23,4 @@ public interface IMcpClient : IAsyncDisposable /// It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt. /// string? ServerInstructions { get; } - - /// - /// Adds a handler for server notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); - - /// - /// Sends a generic JSON-RPC request to the server. - /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the server's response. - /// - /// It is recommended to use the capability-specific methods that use this one in their implementation. - /// Use this method for custom requests or those not yet covered explicitly. - /// - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; - - /// - /// Sends a message to the server. - /// - /// The message. - /// A token to cancel the operation. - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); } \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index b326f3c58..773301c01 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -42,7 +42,10 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp SetRequestHandler( RequestMethods.SamplingCreateMessage, - (request, ct) => samplingHandler(request, ct)); + (request, cancellationToken) => samplingHandler( + request, + request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken)); } if (options.Capabilities?.Roots is { } rootsCapability) @@ -54,7 +57,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp SetRequestHandler( RequestMethods.RootsList, - (request, ct) => rootsHandler(request, ct)); + (request, cancellationToken) => rootsHandler(request, cancellationToken)); } } diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 36742b14a..4d1411742 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -8,9 +8,7 @@ namespace ModelContextProtocol.Client; -/// -/// Provides extensions for operating on MCP clients. -/// +/// Provides extension methods for interacting with an . public static class McpClientExtensions { /// @@ -531,17 +529,33 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat /// /// The with which to satisfy sampling requests. /// The created handler delegate. - public static Func> CreateSamplingHandler(this IChatClient chatClient) + public static Func, CancellationToken, Task> CreateSamplingHandler( + this IChatClient chatClient) { Throw.IfNull(chatClient); - return async (requestParams, cancellationToken) => + return async (requestParams, progress, cancellationToken) => { Throw.IfNull(requestParams); var (messages, options) = requestParams.ToChatClientArguments(); - var response = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - return response.ToCreateMessageResult(); + var progressToken = requestParams.Meta?.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return updates.ToChatResponse().ToCreateMessageResult(); }; } diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs new file mode 100644 index 000000000..d7dbd211e --- /dev/null +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -0,0 +1,35 @@ +using ModelContextProtocol.Protocol.Messages; + +namespace ModelContextProtocol; + +/// Represents a client or server MCP endpoint. +public interface IMcpEndpoint : IAsyncDisposable +{ + /// Sends a generic JSON-RPC request to the connected endpoint. + /// The expected response type. + /// The JSON-RPC request to send. + /// A token to cancel the operation. + /// A task containing the client's response. + Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; + + /// Sends a message to the connected endpoint. + /// The message. + /// A token to cancel the operation. + Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + + /// + /// Adds a handler for server notifications of a specific method. + /// + /// The notification method to handle. + /// The async handler function to process notifications. + /// + /// + /// Each method may have multiple handlers. Adding a handler for a method that already has one + /// will not replace the existing handler. + /// + /// + /// provides constants for common notification methods. + /// + /// + void AddNotificationHandler(string method, Func handler); +} diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs new file mode 100644 index 000000000..c27ada2b2 --- /dev/null +++ b/src/ModelContextProtocol/McpEndpointExtensions.cs @@ -0,0 +1,34 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol; + +/// Provides extension methods for interacting with an . +public static class McpEndpointExtensions +{ + /// Notifies the connected endpoint of progress. + /// The endpoint issueing the notification. + /// The identifying the operation. + /// The progress update to send. + /// A token to cancel the operation. + /// A task representing the completion of the operation. + /// is . + public static Task NotifyProgressAsync( + this IMcpEndpoint endpoint, + ProgressToken progressToken, + ProgressNotificationValue progress, + CancellationToken cancellationToken = default) + { + Throw.IfNull(endpoint); + + return endpoint.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ProgressNotification, + Params = new ProgressNotification() + { + ProgressToken = progressToken, + Progress = progress, + }, + }, cancellationToken); + } +} diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index f33357783..c0cf41977 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -55,7 +55,7 @@ public class SamplingCapability /// Gets or sets the handler for sampling requests. [JsonIgnore] - public Func>? SamplingHandler { get; set; } + public Func, CancellationToken, Task>? SamplingHandler { get; set; } } /// diff --git a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs index 419b6fceb..a5500d410 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of prompts and prompt templates the server has. /// See the schema for details /// -public class ListPromptsRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListPromptsRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs index f4060dbd0..8a54f6e8e 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of resource templates the server has. /// See the schema for details /// -public class ListResourceTemplatesRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} \ No newline at end of file +public class ListResourceTemplatesRequestParams : PaginatedRequestParams; \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs index ad7f19b31..30bea5b87 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of resources the server has. /// See the schema for details /// -public class ListResourcesRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListResourcesRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs index dae1b75c1..a5eec7a15 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs @@ -6,11 +6,4 @@ namespace ModelContextProtocol.Protocol.Types; /// A request from the server to get a list of root URIs from the client. /// See the schema for details /// -public class ListRootsRequestParams -{ - /// - /// Optional progress token for out-of-band progress notifications. - /// - [System.Text.Json.Serialization.JsonPropertyName("progressToken")] - public ProgressToken? ProgressToken { get; init; } -} +public class ListRootsRequestParams : RequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs index 4f18fbb73..64ac18599 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of tools the server has. /// See the schema for details /// -public class ListToolsRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListToolsRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs new file mode 100644 index 000000000..abf47dd3c --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs @@ -0,0 +1,15 @@ +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Used as a base class for paginated requests. +/// See the schema for details +/// +public class PaginatedRequestParams : RequestParams +{ + /// + /// An opaque token representing the current pagination position. + /// If provided, the server should return results starting after this cursor. + /// + [System.Text.Json.Serialization.JsonPropertyName("cursor")] + public string? Cursor { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/Tool.cs b/src/ModelContextProtocol/Protocol/Types/Tool.cs index ed0c71290..dc0b774c0 100644 --- a/src/ModelContextProtocol/Protocol/Types/Tool.cs +++ b/src/ModelContextProtocol/Protocol/Types/Tool.cs @@ -38,7 +38,7 @@ public JsonElement InputSchema { if (!McpJsonUtilities.IsValidMcpToolSchema(value)) { - throw new ArgumentException("The specified document is not a valid MPC tool JSON schema.", nameof(InputSchema)); + throw new ArgumentException("The specified document is not a valid MCP tool JSON schema.", nameof(InputSchema)); } _inputSchema = value; diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index e8dffaf19..19b3967ad 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Server; /// /// Represents a server that can communicate with a client using the MCP protocol. /// -public interface IMcpServer : IAsyncDisposable +public interface IMcpServer : IMcpEndpoint { /// /// Gets the capabilities supported by the client. @@ -26,42 +25,8 @@ public interface IMcpServer : IAsyncDisposable /// IServiceProvider? Services { get; } - /// - /// Adds a handler for client notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); - /// /// Runs the server, listening for and handling client requests. /// Task RunAsync(CancellationToken cancellationToken = default); - - /// - /// Sends a generic JSON-RPC request to the client. - /// NB! This is a temporary method that is available to send not yet implemented feature messages. - /// Once all MCP features are implemented this will be made private, as it is purely a convenience for those who wish to implement features ahead of the library. - /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the client's response. - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; - - /// - /// Sends a message to the client. - /// - /// The message. - /// A token to cancel the operation. - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 8aa32965e..aa35ac339 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -68,7 +68,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? }); SetToolsHandler(options); - SetInitializeHandler(options); SetCompletionHandler(options); SetPingHandler(); diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 3b541ec80..ddaf45f28 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Server; -/// +/// Provides extension methods for interacting with an . public static class McpServerExtensions { /// diff --git a/src/ModelContextProtocol/TokenProgress.cs b/src/ModelContextProtocol/TokenProgress.cs index 46af03f4f..62834e75a 100644 --- a/src/ModelContextProtocol/TokenProgress.cs +++ b/src/ModelContextProtocol/TokenProgress.cs @@ -1,30 +1,16 @@ using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications to the supplied endpoint. +/// progress notifications on the supplied endpoint. /// -internal sealed class TokenProgress(IMcpServer server, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = server.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ProgressNotification, - Params = new ProgressNotification() - { - ProgressToken = progressToken, - Progress = new() - { - Progress = value.Progress, - Total = value.Total, - Message = value.Message, - }, - }, - }, CancellationToken.None); + _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 9b598d332..a08bbb51d 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -355,7 +355,7 @@ public async Task Sampling_Stdio(string clientId) { Sampling = new() { - SamplingHandler = (_, _) => + SamplingHandler = (_, _, _) => { samplingHandlerCalls++; return Task.FromResult(new CreateMessageResult @@ -511,6 +511,9 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() [Fact(Skip = "Requires OpenAI API Key", SkipWhen = nameof(NoOpenAIKeySet))] public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { + var samplingHandler = new OpenAIClient(s_openAIKey) + .AsChatClient("gpt-4o-mini") + .CreateSamplingHandler(); await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" }, @@ -518,7 +521,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { Sampling = new() { - SamplingHandler = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini").CreateSamplingHandler(), + SamplingHandler = samplingHandler, }, }, }, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index a3290049c..bded8e990 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -635,4 +635,47 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); } + + [Fact] + public async Task NotifyProgress_Should_Be_Handled() + { + await using TestServerTransport transport = new(); + var options = CreateOptions(); + + var notificationReceived = new TaskCompletionSource(); + + var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); + server.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => + { + notificationReceived.SetResult(notification); + return Task.CompletedTask; + }); + + Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); + + await transport.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.ProgressNotification, + Params = new ProgressNotification() + { + ProgressToken = new("abc"), + Progress = new() + { + Progress = 50, + Total = 100, + Message = "Progress message", + }, + }, + }, TestContext.Current.CancellationToken); + + var notification = await notificationReceived.Task; + var progress = (ProgressNotification)notification.Params!; + Assert.Equal("\"abc\"", progress.ProgressToken.ToString()); + Assert.Equal(50, progress.Progress.Progress); + Assert.Equal(100, progress.Progress.Total); + Assert.Equal("Progress message", progress.Progress.Message); + + await server.DisposeAsync(); + await serverTask; + } } diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 1874953d9..2d9ba133e 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -116,7 +116,7 @@ public async Task Sampling_Sse_EverythingServer() { Sampling = new() { - SamplingHandler = (_, _) => + SamplingHandler = (_, _, _) => { samplingHandlerCalls++; return Task.FromResult(new CreateMessageResult diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 44befcd10..4fad0d5bc 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -215,7 +215,7 @@ public async Task Sampling_Sse_TestServer() var options = SseServerIntegrationTestFixture.CreateDefaultClientOptions(); options.Capabilities ??= new(); options.Capabilities.Sampling ??= new(); - options.Capabilities.Sampling.SamplingHandler = async (_, _) => + options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => { samplingHandlerCalls++; return new CreateMessageResult