diff --git a/README.md b/README.md index 163d57f8..e4e6cd9d 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ dotnet add package ModelContextProtocol --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClient.CreateAsync` method is used to instantiate and connect an `McpClient` +to a server. Once you have an `McpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions @@ -48,7 +48,7 @@ var clientTransport = new StdioClientTransport(new StdioClientTransportOptions Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); -var client = await McpClientFactory.CreateAsync(clientTransport); +var client = await McpClient.CreateAsync(clientTransport); // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) @@ -224,7 +224,7 @@ McpServerOptions options = new() }, }; -await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); +await using IMcpServer server = McpServer.Create(new StdioServerTransport("MyServer"), options); await server.RunAsync(); ``` diff --git a/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs index 3ac7f567..e6947745 100644 --- a/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs +++ b/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs @@ -12,7 +12,7 @@ public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer thisServer, + McpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index ba597ae8..a84393e1 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -32,7 +32,7 @@ .UseOpenTelemetry(loggerFactory: loggerFactory, configure: o => o.EnableSensitiveData = true) .Build(); -var mcpClient = await McpClientFactory.CreateAsync( +var mcpClient = await McpClient.CreateAsync( new StdioClientTransport(new() { Command = "npx", diff --git a/samples/EverythingServer/LoggingUpdateMessageSender.cs b/samples/EverythingServer/LoggingUpdateMessageSender.cs index 844aa70d..5f524ad8 100644 --- a/samples/EverythingServer/LoggingUpdateMessageSender.cs +++ b/samples/EverythingServer/LoggingUpdateMessageSender.cs @@ -5,7 +5,7 @@ namespace EverythingServer; -public class LoggingUpdateMessageSender(IMcpServer server, Func getMinLevel) : BackgroundService +public class LoggingUpdateMessageSender(McpServer server, Func getMinLevel) : BackgroundService { readonly Dictionary _loggingLevelMap = new() { diff --git a/samples/EverythingServer/SubscriptionMessageSender.cs b/samples/EverythingServer/SubscriptionMessageSender.cs index 774d9852..b071965d 100644 --- a/samples/EverythingServer/SubscriptionMessageSender.cs +++ b/samples/EverythingServer/SubscriptionMessageSender.cs @@ -2,7 +2,7 @@ using ModelContextProtocol; using ModelContextProtocol.Server; -internal class SubscriptionMessageSender(IMcpServer server, HashSet subscriptions) : BackgroundService +internal class SubscriptionMessageSender(McpServer server, HashSet subscriptions) : BackgroundService { protected override async Task ExecuteAsync(CancellationToken stoppingToken) { diff --git a/samples/EverythingServer/Tools/LongRunningTool.cs b/samples/EverythingServer/Tools/LongRunningTool.cs index 27f6ac20..405b5e82 100644 --- a/samples/EverythingServer/Tools/LongRunningTool.cs +++ b/samples/EverythingServer/Tools/LongRunningTool.cs @@ -10,7 +10,7 @@ public class LongRunningTool { [McpServerTool(Name = "longRunningOperation"), Description("Demonstrates a long running operation with progress updates")] public static async Task LongRunningOperation( - IMcpServer server, + McpServer server, RequestContext context, int duration = 10, int steps = 5) diff --git a/samples/EverythingServer/Tools/SampleLlmTool.cs b/samples/EverythingServer/Tools/SampleLlmTool.cs index a58675c3..6bbe6e51 100644 --- a/samples/EverythingServer/Tools/SampleLlmTool.cs +++ b/samples/EverythingServer/Tools/SampleLlmTool.cs @@ -9,7 +9,7 @@ public class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer server, + McpServer server, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/samples/InMemoryTransport/Program.cs b/samples/InMemoryTransport/Program.cs index 67e2d320..141692fe 100644 --- a/samples/InMemoryTransport/Program.cs +++ b/samples/InMemoryTransport/Program.cs @@ -6,7 +6,7 @@ Pipe clientToServerPipe = new(), serverToClientPipe = new(); // Create a server using a stream-based transport over an in-memory pipe. -await using IMcpServer server = McpServerFactory.Create( +await using McpServer server = McpServer.Create( new StreamServerTransport(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()), new McpServerOptions() { @@ -21,7 +21,7 @@ _ = server.RunAsync(); // Connect a client using a stream-based transport over the same in-memory pipe. -await using IMcpClient client = await McpClientFactory.CreateAsync( +await using McpClient client = await McpClient.CreateAsync( new StreamClientTransport(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream())); // List all tools. diff --git a/samples/ProtectedMcpClient/Program.cs b/samples/ProtectedMcpClient/Program.cs index 5871284a..f9c8d4d7 100644 --- a/samples/ProtectedMcpClient/Program.cs +++ b/samples/ProtectedMcpClient/Program.cs @@ -40,7 +40,7 @@ } }, httpClient, consoleLoggerFactory); -var client = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); +var client = await McpClient.CreateAsync(transport, loggerFactory: consoleLoggerFactory); var tools = await client.ListToolsAsync(); if (tools.Count == 0) diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs index d5b887ff..9b8d9edf 100644 --- a/samples/QuickstartClient/Program.cs +++ b/samples/QuickstartClient/Program.cs @@ -33,7 +33,7 @@ Arguments = arguments, }); } -await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport!); +await using var mcpClient = await McpClient.CreateAsync(clientTransport!); var tools = await mcpClient.ListToolsAsync(); foreach (var tool in tools) @@ -62,7 +62,7 @@ var sb = new StringBuilder(); PromptForInput(); -while(Console.ReadLine() is string query && !"exit".Equals(query, StringComparison.OrdinalIgnoreCase)) +while (Console.ReadLine() is string query && !"exit".Equals(query, StringComparison.OrdinalIgnoreCase)) { if (string.IsNullOrWhiteSpace(query)) { diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index a096f930..2c96b8c3 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -12,7 +12,7 @@ public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer thisServer, + McpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 94de9cb9..8d71f516 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -20,7 +20,7 @@ public class HttpServerTransportOptions /// Gets or sets an optional asynchronous callback for running new MCP sessions manually. /// This is useful for running logic before a sessions starts and after it completes. /// - public Func? RunSessionHandler { get; set; } + public Func? RunSessionHandler { get; set; } /// /// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index 6ed72fb6..8b3377a6 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -54,7 +54,7 @@ public async Task HandleSseRequestAsync(HttpContext context) try { - await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); + await using var mcpServer = McpServer.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); context.Features.Set(mcpServer); var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync; diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index bfbd805d..5581cff9 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -222,7 +222,7 @@ private async ValueTask CreateSessionAsync( } } - var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); + var server = McpServer.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); context.Features.Set(server); var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); @@ -281,7 +281,7 @@ private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttp }; } - internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) + internal static Task RunSessionAsync(HttpContext httpContext, McpServer session, CancellationToken requestAborted) => session.RunAsync(requestAborted); // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs index ffeafada..b98806fa 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class StreamableHttpSession( string sessionId, StreamableHttpServerTransport transport, - IMcpServer server, + McpServer server, UserIdClaim? userId, StatefulSessionManager sessionManager) : IAsyncDisposable { @@ -20,7 +20,7 @@ internal sealed class StreamableHttpSession( public string Id => sessionId; public StreamableHttpServerTransport Transport => transport; - public IMcpServer Server => server; + public McpServer Server => server; private StatefulSessionManager SessionManager => sessionManager; public CancellationToken SessionClosed => _disposeCts.Token; diff --git a/src/ModelContextProtocol.Core/AssemblyNameHelper.cs b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs new file mode 100644 index 00000000..292ed2f9 --- /dev/null +++ b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs @@ -0,0 +1,9 @@ +using System.Reflection; + +namespace ModelContextProtocol; + +internal static class AssemblyNameHelper +{ + /// Cached naming information used for MCP session name/version when none is specified. + public static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); +} diff --git a/src/ModelContextProtocol.Core/Client/IClientTransport.cs b/src/ModelContextProtocol.Core/Client/IClientTransport.cs index 52517895..2201e9b4 100644 --- a/src/ModelContextProtocol.Core/Client/IClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/IClientTransport.cs @@ -11,7 +11,7 @@ namespace ModelContextProtocol.Client; /// and servers, allowing different transport protocols to be used interchangeably. /// /// -/// When creating an , is typically used, and is +/// When creating an , is typically used, and is /// provided with the based on expected server configuration. /// /// @@ -39,7 +39,7 @@ public interface IClientTransport /// the transport session as well. /// /// - /// This method is used by to initialize the connection. + /// This method is used by to initialize the connection. /// /// /// The transport connection could not be established. diff --git a/src/ModelContextProtocol.Core/Client/IMcpClient.cs b/src/ModelContextProtocol.Core/Client/IMcpClient.cs index 68a92a2d..dad8f282 100644 --- a/src/ModelContextProtocol.Core/Client/IMcpClient.cs +++ b/src/ModelContextProtocol.Core/Client/IMcpClient.cs @@ -1,10 +1,11 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; /// /// Represents an instance of a Model Context Protocol (MCP) client that connects to and communicates with an MCP server. /// +[Obsolete($"Use {nameof(McpClient)} instead.")] public interface IMcpClient : IMcpEndpoint { /// @@ -44,4 +45,4 @@ public interface IMcpClient : IMcpEndpoint /// /// string? ServerInstructions { get; } -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index dd8c7fe0..c5178e33 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -1,236 +1,751 @@ +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Runtime.CompilerServices; using System.Text.Json; namespace ModelContextProtocol.Client; -/// -internal sealed partial class McpClient : McpEndpoint, IMcpClient +/// +/// Represents an instance of a Model Context Protocol (MCP) client session that connects to and communicates with an MCP server. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract class McpClient : McpSession, IMcpClient +#pragma warning restore CS0618 // Type or member is obsolete { - private static Implementation DefaultImplementation { get; } = new() + /// + /// Gets the capabilities supported by the connected server. + /// + /// The client is not connected. + public abstract ServerCapabilities ServerCapabilities { get; } + + /// + /// Gets the implementation information of the connected server. + /// + /// + /// + /// This property provides identification details about the connected server, including its name and version. + /// It is populated during the initialization handshake and is available after a successful connection. + /// + /// + /// This information can be useful for logging, debugging, compatibility checks, and displaying server + /// information to users. + /// + /// + /// The client is not connected. + public abstract Implementation ServerInfo { get; } + + /// + /// Gets any instructions describing how to use the connected server and its features. + /// + /// + /// + /// This property contains instructions provided by the server during initialization that explain + /// how to effectively use its capabilities. These instructions can include details about available + /// tools, expected input formats, limitations, or any other helpful information. + /// + /// + /// This can be used by clients to improve an LLM's understanding of available tools, prompts, and resources. + /// It can be thought of like a "hint" to the model and may be added to a system prompt. + /// + /// + public abstract string? ServerInstructions { get; } + + /// Creates an , connecting it to the specified server. + /// The transport instance used to communicate with the server. + /// + /// A client configuration object which specifies client capabilities and protocol version. + /// If , details based on the current process will be employed. + /// + /// A logger factory for creating loggers for clients. + /// The to monitor for cancellation requests. The default is . + /// An that's connected to the specified server. + /// is . + /// is . + public static async Task CreateAsync( + IClientTransport clientTransport, + McpClientOptions? clientOptions = null, + ILoggerFactory? loggerFactory = null, + CancellationToken cancellationToken = default) { - Name = DefaultAssemblyName.Name ?? nameof(McpClient), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; + Throw.IfNull(clientTransport); - private readonly IClientTransport _clientTransport; - private readonly McpClientOptions _options; + var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + var endpointName = clientTransport.Name; - private ITransport? _sessionTransport; - private CancellationTokenSource? _connectCts; + var clientSession = new McpClientImpl(transport, endpointName, clientOptions, loggerFactory); + try + { + await clientSession.ConnectAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + await clientSession.DisposeAsync().ConfigureAwait(false); + throw; + } - private ServerCapabilities? _serverCapabilities; - private Implementation? _serverInfo; - private string? _serverInstructions; + return clientSession; + } /// - /// Initializes a new instance of the class. + /// Sends a ping request to verify server connectivity. /// - /// The transport to use for communication with the server. - /// Options for the client, defining protocol version and capabilities. - /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions? options, ILoggerFactory? loggerFactory) - : base(loggerFactory) + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the ping is successful. + /// Thrown when the server cannot be reached or returns an error response. + public Task PingAsync(CancellationToken cancellationToken = default) { - options ??= new(); - - _clientTransport = clientTransport; - _options = options; + var opts = McpJsonUtilities.DefaultOptions; + opts.MakeReadOnly(); + return this.SendRequestAsync( + RequestMethods.Ping, + parameters: null, + serializerOptions: opts, + cancellationToken: cancellationToken).AsTask(); + } - EndpointName = clientTransport.Name; + /// + /// Retrieves a list of available tools from the server. + /// + /// The serializer options governing tool parameter serialization. If null, the default options will be used. + /// The to monitor for cancellation requests. The default is . + /// A list of all available tools as instances. + public async ValueTask> ListToolsAsync( + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); - if (options.Capabilities is { } capabilities) + List? tools = null; + string? cursor = null; + do { - if (capabilities.NotificationHandlers is { } notificationHandlers) + var toolResults = await SendRequestAsync( + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + tools ??= new List(toolResults.Tools.Count); + foreach (var tool in toolResults.Tools) { - NotificationHandlers.RegisterRange(notificationHandlers); + tools.Add(new McpClientTool(this, tool, serializerOptions)); } - if (capabilities.Sampling is { } samplingCapability) + cursor = toolResults.NextCursor; + } + while (cursor is not null); + + return tools; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available tools from the server. + /// + /// The serializer options governing tool parameter serialization. If null, the default options will be used. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available tools as instances. + public async IAsyncEnumerable EnumerateToolsAsync( + JsonSerializerOptions? serializerOptions = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + string? cursor = null; + do + { + var toolResults = await SendRequestAsync( + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var tool in toolResults.Tools) { - if (samplingCapability.SamplingHandler is not { } samplingHandler) - { - throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); - } + yield return new McpClientTool(this, tool, serializerOptions); + } + + cursor = toolResults.NextCursor; + } + while (cursor is not null); + } - RequestHandlers.Set( - RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); + /// + /// Retrieves a list of available prompts from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available prompts as instances. + public async ValueTask> ListPromptsAsync( + CancellationToken cancellationToken = default) + { + List? prompts = null; + string? cursor = null; + do + { + var promptResults = await SendRequestAsync( + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + prompts ??= new List(promptResults.Prompts.Count); + foreach (var prompt in promptResults.Prompts) + { + prompts.Add(new McpClientPrompt(this, prompt)); } - if (capabilities.Roots is { } rootsCapability) + cursor = promptResults.NextCursor; + } + while (cursor is not null); + + return prompts; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available prompts from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available prompts as instances. + public async IAsyncEnumerable EnumeratePromptsAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var promptResults = await SendRequestAsync( + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var prompt in promptResults.Prompts) { - if (rootsCapability.RootsHandler is not { } rootsHandler) - { - throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); - } + yield return new(this, prompt); + } + + cursor = promptResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a specific prompt from the MCP server. + /// + /// The name of the prompt to retrieve. + /// Optional arguments for the prompt. Keys are parameter names, and values are the argument values. + /// The serialization options governing argument serialization. + /// The to monitor for cancellation requests. The default is . + /// A task containing the prompt's result with content and messages. + public ValueTask GetPromptAsync( + string name, + IReadOnlyDictionary? arguments = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(name); + + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + return SendRequestAsync( + RequestMethods.PromptsGet, + new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult, + cancellationToken: cancellationToken); + } - RequestHandlers.Set( - RequestMethods.RootsList, - (request, _, cancellationToken) => rootsHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult); + /// + /// Retrieves a list of available resource templates from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available resource templates as instances. + public async ValueTask> ListResourceTemplatesAsync( + CancellationToken cancellationToken = default) + { + List? resourceTemplates = null; + + string? cursor = null; + do + { + var templateResults = await SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + resourceTemplates ??= new List(templateResults.ResourceTemplates.Count); + foreach (var template in templateResults.ResourceTemplates) + { + resourceTemplates.Add(new McpClientResourceTemplate(this, template)); } - if (capabilities.Elicitation is { } elicitationCapability) + cursor = templateResults.NextCursor; + } + while (cursor is not null); + + return resourceTemplates; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available resource templates as instances. + public async IAsyncEnumerable EnumerateResourceTemplatesAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var templateResults = await SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var templateResult in templateResults.ResourceTemplates) { - if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) - { - throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); - } + yield return new McpClientResourceTemplate(this, templateResult); + } + + cursor = templateResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a list of available resources from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available resources as instances. + public async ValueTask> ListResourcesAsync( + CancellationToken cancellationToken = default) + { + List? resources = null; - RequestHandlers.Set( - RequestMethods.ElicitationCreate, - (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult); + string? cursor = null; + do + { + var resourceResults = await SendRequestAsync( + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + resources ??= new List(resourceResults.Resources.Count); + foreach (var resource in resourceResults.Resources) + { + resources.Add(new McpClientResource(this, resource)); } + + cursor = resourceResults.NextCursor; } + while (cursor is not null); + + return resources; } - /// - public string? SessionId + /// + /// Creates an enumerable for asynchronously enumerating all available resources from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available resources as instances. + public async IAsyncEnumerable EnumerateResourcesAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) { - get + string? cursor = null; + do { - if (_sessionTransport is null) + var resourceResults = await SendRequestAsync( + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var resource in resourceResults.Resources) { - throw new InvalidOperationException("Must have already initialized a session when invoking this property."); + yield return new McpClientResource(this, resource); } - return _sessionTransport.SessionId; + cursor = resourceResults.NextCursor; } + while (cursor is not null); } - /// - public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); + /// + /// Reads a resource from the server. + /// + /// The uri of the resource. + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesRead, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); + } - /// - public Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); + /// + /// Reads a resource from the server. + /// + /// The uri of the resource. + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); - /// - public string? ServerInstructions => _serverInstructions; + return ReadResourceAsync(uri.ToString(), cancellationToken); + } - /// - public override string EndpointName { get; } + /// + /// Reads a resource from the server. + /// + /// The uri template of the resource. + /// Arguments to use to format . + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uriTemplate); + Throw.IfNull(arguments); + + return SendRequestAsync( + RequestMethods.ResourcesRead, + new() { Uri = UriTemplate.FormatUri(uriTemplate, arguments) }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); + } /// - /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// Requests completion suggestions for a prompt argument or resource reference. /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) + /// The reference object specifying the type and optional URI or name. + /// The name of the argument for which completions are requested. + /// The current value of the argument, used to filter relevant completions. + /// The to monitor for cancellation requests. The default is . + /// A containing completion suggestions. + public ValueTask CompleteAsync(Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) { - _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cancellationToken = _connectCts.Token; + Throw.IfNull(reference); + Throw.IfNullOrWhiteSpace(argumentName); - try - { - // Connect transport - _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - InitializeSession(_sessionTransport); - // We don't want the ConnectAsync token to cancel the session after we've successfully connected. - // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); + return SendRequestAsync( + RequestMethods.CompletionComplete, + new() + { + Ref = reference, + Argument = new Argument { Name = argumentName, Value = argumentValue } + }, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult, + cancellationToken: cancellationToken); + } + + /// + /// Subscribes to a resource on the server to receive notifications when it changes. + /// + /// The URI of the resource to which to subscribe. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task SubscribeToResourceAsync(string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesSubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Subscribes to a resource on the server to receive notifications when it changes. + /// + /// The URI of the resource to which to subscribe. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task SubscribeToResourceAsync(Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return SubscribeToResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. + /// + /// The URI of the resource to unsubscribe from. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task UnsubscribeFromResourceAsync(string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesUnsubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. + /// + /// The URI of the resource to unsubscribe from. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task UnsubscribeFromResourceAsync(Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return UnsubscribeFromResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Invokes a tool on the server. + /// + /// The name of the tool to call on the server.. + /// An optional dictionary of arguments to pass to the tool. + /// Optional progress reporter for server notifications. + /// JSON serializer options. + /// A cancellation token. + /// The from the tool execution. + public ValueTask CallToolAsync( + string toolName, + IReadOnlyDictionary? arguments = null, + IProgress? progress = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + Throw.IfNull(toolName); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); - // Perform initialization sequence - using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - initializationCts.CancelAfter(_options.InitializationTimeout); + if (progress is not null) + { + return SendRequestWithProgressAsync(toolName, arguments, progress, serializerOptions, cancellationToken); + } - try + return SendRequestAsync( + RequestMethods.ToolsCall, + new() { - // Send initialize request - string requestProtocol = _options.ProtocolVersion ?? McpSession.LatestProtocolVersion; - var initializeResponse = await this.SendRequestAsync( - RequestMethods.Initialize, - new InitializeRequestParams - { - ProtocolVersion = requestProtocol, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo ?? DefaultImplementation, - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult, - cancellationToken: initializationCts.Token).ConfigureAwait(false); - - // Store server information - if (_logger.IsEnabled(LogLevel.Information)) + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult, + cancellationToken: cancellationToken); + + async ValueTask SendRequestWithProgressAsync( + string toolName, + IReadOnlyDictionary? arguments, + IProgress progress, + JsonSerializerOptions serializerOptions, + CancellationToken cancellationToken) + { + ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); + + await using var _ = RegisterNotificationHandler(NotificationMethods.ProgressNotification, + (notification, cancellationToken) => { - LogServerCapabilitiesReceived(EndpointName, - capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), - serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); - } + if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && + pn.ProgressToken == progressToken) + { + progress.Report(pn.Progress); + } - _serverCapabilities = initializeResponse.Capabilities; - _serverInfo = initializeResponse.ServerInfo; - _serverInstructions = initializeResponse.Instructions; + return default; + }).ConfigureAwait(false); - // Validate protocol version - bool isResponseProtocolValid = - _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : - McpSession.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); - if (!isResponseProtocolValid) + return await SendRequestAsync( + RequestMethods.ToolsCall, + new() { - LogServerProtocolVersionMismatch(EndpointName, requestProtocol, initializeResponse.ProtocolVersion); - throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); - } + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + ProgressToken = progressToken, + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Converts the contents of a into a pair of + /// and instances to use + /// as inputs into a operation. + /// + /// + /// The created pair of messages and options. + /// is . + internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( + CreateMessageRequestParams requestParams) + { + Throw.IfNull(requestParams); - // Send initialized notification - await this.SendNotificationAsync( - NotificationMethods.InitializedNotification, - new InitializedNotificationParams(), - McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, - cancellationToken: initializationCts.Token).ConfigureAwait(false); + ChatOptions? options = null; - } - catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) - { - LogClientInitializationTimeout(EndpointName); - throw new TimeoutException("Initialization timed out", oce); - } + if (requestParams.MaxTokens is int maxTokens) + { + (options ??= new()).MaxOutputTokens = maxTokens; } - catch (Exception e) + + if (requestParams.Temperature is float temperature) { - LogClientInitializationError(EndpointName, e); - await DisposeAsync().ConfigureAwait(false); - throw; + (options ??= new()).Temperature = temperature; + } + + if (requestParams.StopSequences is { } stopSequences) + { + (options ??= new()).StopSequences = stopSequences.ToArray(); } + + List messages = + (from sm in requestParams.Messages + let aiContent = sm.Content.ToAIContent() + where aiContent is not null + select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) + .ToList(); + + return (messages, options); } - /// - public override async ValueTask DisposeUnsynchronizedAsync() + /// Converts the contents of a into a . + /// The whose contents should be extracted. + /// The created . + /// is . + internal static CreateMessageResult ToCreateMessageResult(ChatResponse chatResponse) { - try + Throw.IfNull(chatResponse); + + // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports + // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one + // in any of the response messages, or we'll use all the text from them concatenated, otherwise. + + ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); + + ContentBlock? content = null; + if (lastMessage is not null) { - if (_connectCts is not null) + foreach (var lmc in lastMessage.Contents) { - await _connectCts.CancelAsync().ConfigureAwait(false); - _connectCts.Dispose(); + if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) + { + content = dc.ToContent(); + } } - - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); } - finally + + return new() + { + Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, + Model = chatResponse.ModelId ?? "unknown", + Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, + StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", + }; + } + + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => { - if (_sessionTransport is not null) + Throw.IfNull(requestParams); + + var (messages, options) = ToChatClientArguments(requestParams); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { - await _sessionTransport.DisposeAsync().ConfigureAwait(false); + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } } - } + + return ToCreateMessageResult(updates.ToChatResponse()); + }; } - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] - private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); + /// + /// Sets the logging level for the server to control which log messages are sent to the client. + /// + /// The minimum severity level of log messages to receive from the server. + /// The to monitor for cancellation requests. The default is . + /// A task representing the asynchronous operation. + public Task SetLoggingLevel(LoggingLevel level, CancellationToken cancellationToken = default) + { + return SendRequestAsync( + RequestMethods.LoggingSetLevel, + new() { Level = level }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] - private partial void LogClientInitializationError(string endpointName, Exception exception); + /// + /// Sets the logging level for the server to control which log messages are sent to the client. + /// + /// The minimum severity level of log messages to receive from the server. + /// The to monitor for cancellation requests. The default is . + /// A task representing the asynchronous operation. + public Task SetLoggingLevel(LogLevel level, CancellationToken cancellationToken = default) => + SetLoggingLevel(McpServerImpl.ToLoggingLevel(level), cancellationToken); + + /// Convers a dictionary with values to a dictionary with values. + private static Dictionary? ToArgumentsDictionary( + IReadOnlyDictionary? arguments, JsonSerializerOptions options) + { + var typeInfo = options.GetTypeInfo(); - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] - private partial void LogClientInitializationTimeout(string endpointName); + Dictionary? result = null; + if (arguments is not null) + { + result = new(arguments.Count); + foreach (var kvp in arguments) + { + result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); + } + } - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] - private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); -} \ No newline at end of file + return result; + } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs index 60a9c3a6..9ceea460 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs @@ -1,14 +1,13 @@ using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Text.Json; namespace ModelContextProtocol.Client; /// -/// Provides extension methods for interacting with an . +/// Provides extension methods for interacting with an . /// /// /// @@ -19,6 +18,128 @@ namespace ModelContextProtocol.Client; /// public static class McpClientExtensions { + /// + /// Converts the contents of a into a pair of + /// and instances to use + /// as inputs into a operation. + /// + /// + /// The created pair of messages and options. + /// is . + internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( + this CreateMessageRequestParams requestParams) + { + Throw.IfNull(requestParams); + + ChatOptions? options = null; + + if (requestParams.MaxTokens is int maxTokens) + { + (options ??= new()).MaxOutputTokens = maxTokens; + } + + if (requestParams.Temperature is float temperature) + { + (options ??= new()).Temperature = temperature; + } + + if (requestParams.StopSequences is { } stopSequences) + { + (options ??= new()).StopSequences = stopSequences.ToArray(); + } + + List messages = + (from sm in requestParams.Messages + let aiContent = sm.Content.ToAIContent() + where aiContent is not null + select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) + .ToList(); + + return (messages, options); + } + + /// Converts the contents of a into a . + /// The whose contents should be extracted. + /// The created . + /// is . + internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chatResponse) + { + Throw.IfNull(chatResponse); + + // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports + // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one + // in any of the response messages, or we'll use all the text from them concatenated, otherwise. + + ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); + + ContentBlock? content = null; + if (lastMessage is not null) + { + foreach (var lmc in lastMessage.Contents) + { + if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) + { + content = dc.ToContent(); + } + } + } + + return new() + { + Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, + Model = chatResponse.ModelId ?? "unknown", + Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, + StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", + }; + } + + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// + /// + /// This method creates a function that converts MCP message requests into chat client calls, enabling + /// an MCP client to generate text or other content using an actual AI model via the provided chat client. + /// + /// + /// The handler can process text messages, image messages, and resource messages as defined in the + /// Model Context Protocol. + /// + /// + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + this IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => + { + Throw.IfNull(requestParams); + + var (messages, options) = requestParams.ToChatClientArguments(); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return updates.ToChatResponse().ToCreateMessageResult(); + }; + } + /// /// Sends a ping request to verify server connectivity. /// @@ -38,17 +159,9 @@ public static class McpClientExtensions /// /// is . /// Thrown when the server cannot be reached or returns an error response. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.PingAsync)} instead.")] public static Task PingAsync(this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - return client.SendRequestAsync( - RequestMethods.Ping, - parameters: null, - McpJsonUtilities.JsonContext.Default.Object!, - McpJsonUtilities.JsonContext.Default.Object, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).PingAsync(cancellationToken); /// /// Retrieves a list of available tools from the server. @@ -89,39 +202,12 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat /// /// /// is . - public static async ValueTask> ListToolsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListToolsAsync)} instead.")] + public static ValueTask> ListToolsAsync( this IMcpClient client, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - List? tools = null; - string? cursor = null; - do - { - var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - tools ??= new List(toolResults.Tools.Count); - foreach (var tool in toolResults.Tools) - { - tools.Add(new McpClientTool(client, tool, serializerOptions)); - } - - cursor = toolResults.NextCursor; - } - while (cursor is not null); - - return tools; - } + => AsClientOrThrow(client).ListToolsAsync(serializerOptions, cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available tools from the server. @@ -155,35 +241,12 @@ public static async ValueTask> ListToolsAsync( /// /// /// is . - public static async IAsyncEnumerable EnumerateToolsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateToolsAsync)} instead.")] + public static IAsyncEnumerable EnumerateToolsAsync( this IMcpClient client, JsonSerializerOptions? serializerOptions = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - string? cursor = null; - do - { - var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var tool in toolResults.Tools) - { - yield return new McpClientTool(client, tool, serializerOptions); - } - - cursor = toolResults.NextCursor; - } - while (cursor is not null); - } + CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateToolsAsync(serializerOptions, cancellationToken); /// /// Retrieves a list of available prompts from the server. @@ -202,34 +265,10 @@ public static async IAsyncEnumerable EnumerateToolsAsync( /// /// /// is . - public static async ValueTask> ListPromptsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListPromptsAsync)} instead.")] + public static ValueTask> ListPromptsAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? prompts = null; - string? cursor = null; - do - { - var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - prompts ??= new List(promptResults.Prompts.Count); - foreach (var prompt in promptResults.Prompts) - { - prompts.Add(new McpClientPrompt(client, prompt)); - } - - cursor = promptResults.NextCursor; - } - while (cursor is not null); - - return prompts; - } + => AsClientOrThrow(client).ListPromptsAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available prompts from the server. @@ -258,30 +297,10 @@ public static async ValueTask> ListPromptsAsync( /// /// /// is . - public static async IAsyncEnumerable EnumeratePromptsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var prompt in promptResults.Prompts) - { - yield return new(client, prompt); - } - - cursor = promptResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumeratePromptsAsync)} instead.")] + public static IAsyncEnumerable EnumeratePromptsAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumeratePromptsAsync(cancellationToken); /// /// Retrieves a specific prompt from the MCP server. @@ -308,26 +327,14 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( /// /// Thrown when the prompt does not exist, when required arguments are missing, or when the server encounters an error processing the prompt. /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.GetPromptAsync)} instead.")] public static ValueTask GetPromptAsync( this IMcpClient client, string name, IReadOnlyDictionary? arguments = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(name); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - return client.SendRequestAsync( - RequestMethods.PromptsGet, - new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, - McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, - McpJsonUtilities.JsonContext.Default.GetPromptResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).GetPromptAsync(name, arguments, serializerOptions, cancellationToken); /// /// Retrieves a list of available resource templates from the server. @@ -346,35 +353,10 @@ public static ValueTask GetPromptAsync( /// /// /// is . - public static async ValueTask> ListResourceTemplatesAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListResourceTemplatesAsync)} instead.")] + public static ValueTask> ListResourceTemplatesAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? resourceTemplates = null; - - string? cursor = null; - do - { - var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - resourceTemplates ??= new List(templateResults.ResourceTemplates.Count); - foreach (var template in templateResults.ResourceTemplates) - { - resourceTemplates.Add(new McpClientResourceTemplate(client, template)); - } - - cursor = templateResults.NextCursor; - } - while (cursor is not null); - - return resourceTemplates; - } + => AsClientOrThrow(client).ListResourceTemplatesAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. @@ -403,30 +385,10 @@ public static async ValueTask> ListResourceTemp /// /// /// is . - public static async IAsyncEnumerable EnumerateResourceTemplatesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var templateResult in templateResults.ResourceTemplates) - { - yield return new McpClientResourceTemplate(client, templateResult); - } - - cursor = templateResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateResourceTemplatesAsync)} instead.")] + public static IAsyncEnumerable EnumerateResourceTemplatesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateResourceTemplatesAsync(cancellationToken); /// /// Retrieves a list of available resources from the server. @@ -457,35 +419,10 @@ public static async IAsyncEnumerable EnumerateResourc /// /// /// is . - public static async ValueTask> ListResourcesAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListResourcesAsync)} instead.")] + public static ValueTask> ListResourcesAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? resources = null; - - string? cursor = null; - do - { - var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - resources ??= new List(resourceResults.Resources.Count); - foreach (var resource in resourceResults.Resources) - { - resources.Add(new McpClientResource(client, resource)); - } - - cursor = resourceResults.NextCursor; - } - while (cursor is not null); - - return resources; - } + => AsClientOrThrow(client).ListResourcesAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available resources from the server. @@ -514,30 +451,10 @@ public static async ValueTask> ListResourcesAsync( /// /// /// is . - public static async IAsyncEnumerable EnumerateResourcesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var resource in resourceResults.Resources) - { - yield return new McpClientResource(client, resource); - } - - cursor = resourceResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateResourcesAsync)} instead.")] + public static IAsyncEnumerable EnumerateResourcesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateResourcesAsync(cancellationToken); /// /// Reads a resource from the server. @@ -548,19 +465,10 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead.")] public static ValueTask ReadResourceAsync( this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uri, cancellationToken); /// /// Reads a resource from the server. @@ -570,14 +478,10 @@ public static ValueTask ReadResourceAsync( /// The to monitor for cancellation requests. The default is . /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead.")] public static ValueTask ReadResourceAsync( this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return ReadResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uri, cancellationToken); /// /// Reads a resource from the server. @@ -589,20 +493,10 @@ public static ValueTask ReadResourceAsync( /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead.")] public static ValueTask ReadResourceAsync( this IMcpClient client, string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uriTemplate); - Throw.IfNull(arguments); - - return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new() { Uri = UriTemplate.FormatUri(uriTemplate, arguments) }, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uriTemplate, arguments, cancellationToken); /// /// Requests completion suggestions for a prompt argument or resource reference. @@ -633,23 +527,9 @@ public static ValueTask ReadResourceAsync( /// is . /// is empty or composed entirely of whitespace. /// The server returned an error response. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CompleteAsync)} instead.")] public static ValueTask CompleteAsync(this IMcpClient client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(reference); - Throw.IfNullOrWhiteSpace(argumentName); - - return client.SendRequestAsync( - RequestMethods.CompletionComplete, - new() - { - Ref = reference, - Argument = new Argument { Name = argumentName, Value = argumentValue } - }, - McpJsonUtilities.JsonContext.Default.CompleteRequestParams, - McpJsonUtilities.JsonContext.Default.CompleteResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).CompleteAsync(reference, argumentName, argumentValue, cancellationToken); /// /// Subscribes to a resource on the server to receive notifications when it changes. @@ -676,18 +556,9 @@ public static ValueTask CompleteAsync(this IMcpClient client, Re /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.SubscribeToResourceAsync)} instead.")] public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesSubscribe, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).SubscribeToResourceAsync(uri, cancellationToken); /// /// Subscribes to a resource on the server to receive notifications when it changes. @@ -713,13 +584,9 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, /// /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.SubscribeToResourceAsync)} instead.")] public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return SubscribeToResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).SubscribeToResourceAsync(uri, cancellationToken); /// /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. @@ -745,18 +612,9 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, Can /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.UnsubscribeFromResourceAsync)} instead.")] public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesUnsubscribe, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).UnsubscribeFromResourceAsync(uri, cancellationToken); /// /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. @@ -781,13 +639,9 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.UnsubscribeFromResourceAsync)} instead.")] public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return UnsubscribeFromResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).UnsubscribeFromResourceAsync(uri, cancellationToken); /// /// Invokes a tool on the server. @@ -824,6 +678,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, /// }); /// /// + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CallToolAsync)} instead.")] public static ValueTask CallToolAsync( this IMcpClient client, string toolName, @@ -831,264 +686,27 @@ public static ValueTask CallToolAsync( IProgress? progress = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(toolName); - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - if (progress is not null) - { - return SendRequestWithProgressAsync(client, toolName, arguments, progress, serializerOptions, cancellationToken); - } - - return client.SendRequestAsync( - RequestMethods.ToolsCall, - new() - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken); - - static async ValueTask SendRequestWithProgressAsync( - IMcpClient client, - string toolName, - IReadOnlyDictionary? arguments, - IProgress progress, - JsonSerializerOptions serializerOptions, - CancellationToken cancellationToken) - { - ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); - - await using var _ = client.RegisterNotificationHandler(NotificationMethods.ProgressNotification, - (notification, cancellationToken) => - { - if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && - pn.ProgressToken == progressToken) - { - progress.Report(pn.Progress); - } - - return default; - }).ConfigureAwait(false); - - return await client.SendRequestAsync( - RequestMethods.ToolsCall, - new() - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - ProgressToken = progressToken, - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - } - } - - /// - /// Converts the contents of a into a pair of - /// and instances to use - /// as inputs into a operation. - /// - /// - /// The created pair of messages and options. - /// is . - internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( - this CreateMessageRequestParams requestParams) - { - Throw.IfNull(requestParams); - - ChatOptions? options = null; + => AsClientOrThrow(client).CallToolAsync(toolName, arguments, progress, serializerOptions, cancellationToken); - if (requestParams.MaxTokens is int maxTokens) - { - (options ??= new()).MaxOutputTokens = maxTokens; - } - - if (requestParams.Temperature is float temperature) - { - (options ??= new()).Temperature = temperature; - } - - if (requestParams.StopSequences is { } stopSequences) - { - (options ??= new()).StopSequences = stopSequences.ToArray(); - } - - List messages = - (from sm in requestParams.Messages - let aiContent = sm.Content.ToAIContent() - where aiContent is not null - select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) - .ToList(); - - return (messages, options); - } - - /// Converts the contents of a into a . - /// The whose contents should be extracted. - /// The created . - /// is . - internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chatResponse) + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpClient AsClientOrThrow(IMcpClient client, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - Throw.IfNull(chatResponse); - - // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports - // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one - // in any of the response messages, or we'll use all the text from them concatenated, otherwise. - - ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); - - ContentBlock? content = null; - if (lastMessage is not null) + if (client is not McpClient mcpClient) { - foreach (var lmc in lastMessage.Contents) - { - if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) - { - content = dc.ToContent(); - } - } + ThrowInvalidEndpointType(memberName); } - return new() - { - Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, - Model = chatResponse.ModelId ?? "unknown", - Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, - StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", - }; - } - - /// - /// Creates a sampling handler for use with that will - /// satisfy sampling requests using the specified . - /// - /// The with which to satisfy sampling requests. - /// The created handler delegate that can be assigned to . - /// - /// - /// This method creates a function that converts MCP message requests into chat client calls, enabling - /// an MCP client to generate text or other content using an actual AI model via the provided chat client. - /// - /// - /// The handler can process text messages, image messages, and resource messages as defined in the - /// Model Context Protocol. - /// - /// - /// is . - public static Func, CancellationToken, ValueTask> CreateSamplingHandler( - this IChatClient chatClient) - { - Throw.IfNull(chatClient); - - return async (requestParams, progress, cancellationToken) => - { - Throw.IfNull(requestParams); - - var (messages, options) = requestParams.ToChatClientArguments(); - var progressToken = requestParams.ProgressToken; - - List updates = []; - await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) - { - updates.Add(update); - - if (progressToken is not null) - { - progress.Report(new() - { - Progress = updates.Count, - }); - } - } - - return updates.ToChatResponse().ToCreateMessageResult(); - }; - } - - /// - /// Sets the logging level for the server to control which log messages are sent to the client. - /// - /// The client instance used to communicate with the MCP server. - /// The minimum severity level of log messages to receive from the server. - /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. - /// - /// - /// After this request is processed, the server will send log messages at or above the specified - /// logging level as notifications to the client. For example, if is set, - /// the client will receive , , - /// , , and - /// level messages. - /// - /// - /// To receive all log messages, set the level to . - /// - /// - /// Log messages are delivered as notifications to the client and can be captured by registering - /// appropriate event handlers with the client implementation, such as with . - /// - /// - /// is . - public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - return client.SendRequestAsync( - RequestMethods.LoggingSetLevel, - new() { Level = level }, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } - - /// - /// Sets the logging level for the server to control which log messages are sent to the client. - /// - /// The client instance used to communicate with the MCP server. - /// The minimum severity level of log messages to receive from the server. - /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. - /// - /// - /// After this request is processed, the server will send log messages at or above the specified - /// logging level as notifications to the client. For example, if is set, - /// the client will receive , , - /// and level messages. - /// - /// - /// To receive all log messages, set the level to . - /// - /// - /// Log messages are delivered as notifications to the client and can be captured by registering - /// appropriate event handlers with the client implementation, such as with . - /// - /// - /// is . - public static Task SetLoggingLevel(this IMcpClient client, LogLevel level, CancellationToken cancellationToken = default) => - SetLoggingLevel(client, McpServer.ToLoggingLevel(level), cancellationToken); - - /// Convers a dictionary with values to a dictionary with values. - private static Dictionary? ToArgumentsDictionary( - IReadOnlyDictionary? arguments, JsonSerializerOptions options) - { - var typeInfo = options.GetTypeInfo(); - - Dictionary? result = null; - if (arguments is not null) - { - result = new(arguments.Count); - foreach (var kvp in arguments) - { - result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); - } - } + return mcpClient; - return result; + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpClient)}' are supported. " + + $"Prefer using '{nameof(McpClient)}.{memberName}' instead, as " + + $"'{nameof(McpClientExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs index 30b3a947..756281a1 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; namespace ModelContextProtocol.Client; @@ -10,6 +10,7 @@ namespace ModelContextProtocol.Client; /// that connect to MCP servers. It handles the creation and connection /// of appropriate implementations through the supplied transport. /// +[Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CreateAsync)} instead.")] public static partial class McpClientFactory { /// Creates an , connecting it to the specified server. @@ -28,27 +29,5 @@ public static async Task CreateAsync( McpClientOptions? clientOptions = null, ILoggerFactory? loggerFactory = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(clientTransport); - - McpClient client = new(clientTransport, clientOptions, loggerFactory); - try - { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - if (loggerFactory?.CreateLogger(typeof(McpClientFactory)) is ILogger logger) - { - logger.LogClientCreated(client.EndpointName); - } - } - catch - { - await client.DisposeAsync().ConfigureAwait(false); - throw; - } - - return client; - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] - private static partial void LogClientCreated(this ILogger logger, string endpointName); + => await McpClient.CreateAsync(clientTransport, clientOptions, loggerFactory, cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs new file mode 100644 index 00000000..4a1e1397 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -0,0 +1,241 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// +internal sealed partial class McpClientImpl : McpClient +{ + private static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpClient), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _transport; + private readonly string _endpointName; + private readonly McpClientOptions _options; + private readonly McpSessionHandler _sessionHandler; + + private CancellationTokenSource? _connectCts; + + private ServerCapabilities? _serverCapabilities; + private Implementation? _serverInfo; + private string? _serverInstructions; + + private int _isDisposed; + + /// + /// Initializes a new instance of the class. + /// + /// The transport to use for communication with the server. + /// The name of the endpoint for logging and debug purposes. + /// Options for the client, defining protocol version and capabilities. + /// The logger factory. + internal McpClientImpl(ITransport transport, string endpointName, McpClientOptions? options, ILoggerFactory? loggerFactory) + { + options ??= new(); + + _transport = transport; + _endpointName = $"Client ({options.ClientInfo?.Name ?? DefaultImplementation.Name} {options.ClientInfo?.Version ?? DefaultImplementation.Version})"; + _options = options; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + var notificationHandlers = new NotificationHandlers(); + var requestHandlers = new RequestHandlers(); + + if (options.Capabilities is { } capabilities) + { + RegisterHandlers(capabilities, notificationHandlers, requestHandlers); + } + + _sessionHandler = new McpSessionHandler(isServer: false, transport, endpointName, requestHandlers, notificationHandlers, _logger); + } + + private void RegisterHandlers(ClientCapabilities capabilities, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers) + { + if (capabilities.NotificationHandlers is { } notificationHandlersFromCapabilities) + { + notificationHandlers.RegisterRange(notificationHandlersFromCapabilities); + } + + if (capabilities.Sampling is { } samplingCapability) + { + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, _, cancellationToken) => samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); + } + + if (capabilities.Roots is { } rootsCapability) + { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.RootsList, + (request, _, cancellationToken) => rootsHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); + } + + if (capabilities.Elicitation is { } elicitationCapability) + { + if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) + { + throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.ElicitationCreate, + (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult); + } + } + + /// + public override string? SessionId => _transport.SessionId; + + /// + public override ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); + + /// + public override Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); + + /// + public override string? ServerInstructions => _serverInstructions; + + /// + /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _connectCts.Token; + + try + { + // We don't want the ConnectAsync token to cancel the message processing loop after we've successfully connected. + // The session handler handles cancelling the loop upon its disposal. + _ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None); + + // Perform initialization sequence + using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + initializationCts.CancelAfter(_options.InitializationTimeout); + + try + { + // Send initialize request + string requestProtocol = _options.ProtocolVersion ?? McpSessionHandler.LatestProtocolVersion; + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams + { + ProtocolVersion = requestProtocol, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo ?? DefaultImplementation, + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + // Store server information + if (_logger.IsEnabled(LogLevel.Information)) + { + LogServerCapabilitiesReceived(_endpointName, + capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), + serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); + } + + _serverCapabilities = initializeResponse.Capabilities; + _serverInfo = initializeResponse.ServerInfo; + _serverInstructions = initializeResponse.Instructions; + + // Validate protocol version + bool isResponseProtocolValid = + _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : + McpSessionHandler.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); + if (!isResponseProtocolValid) + { + LogServerProtocolVersionMismatch(_endpointName, requestProtocol, initializeResponse.ProtocolVersion); + throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); + } + + // Send initialized notification + await this.SendNotificationAsync( + NotificationMethods.InitializedNotification, + new InitializedNotificationParams(), + McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + } + catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + LogClientInitializationTimeout(_endpointName); + throw new TimeoutException("Initialization timed out", oce); + } + } + catch (Exception e) + { + LogClientInitializationError(_endpointName, e); + await DisposeAsync().ConfigureAwait(false); + throw; + } + + LogClientConnected(_endpointName); + } + + /// + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public override async ValueTask DisposeAsync() + { + if (Interlocked.CompareExchange(ref _isDisposed, 1, 0) != 0) + { + return; + } + + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + await _transport.DisposeAsync().ConfigureAwait(false); + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] + private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] + private partial void LogClientInitializationError(string endpointName, Exception exception); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] + private partial void LogClientInitializationTimeout(string endpointName); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] + private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] + private partial void LogClientConnected(string endpointName); +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs index 76099d0d..d4ed41db 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs @@ -3,10 +3,10 @@ namespace ModelContextProtocol.Client; /// -/// Provides configuration options for creating instances. +/// Provides configuration options for creating instances. /// /// -/// These options are typically passed to when creating a client. +/// These options are typically passed to when creating a client. /// They define client capabilities, protocol version, and other client-specific settings. /// public sealed class McpClientOptions diff --git a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs index 43fc759a..5a618242 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs @@ -10,8 +10,8 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a prompt defined on an MCP server. It allows /// retrieving the prompt's content by sending a request to the server with optional arguments. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// /// Each prompt has a name and optionally a description, and it can be invoked with arguments @@ -20,9 +20,9 @@ namespace ModelContextProtocol.Client; /// public sealed class McpClientPrompt { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientPrompt(IMcpClient client, Prompt prompt) + internal McpClientPrompt(McpClient client, Prompt prompt) { _client = client; ProtocolPrompt = prompt; @@ -63,7 +63,7 @@ internal McpClientPrompt(IMcpClient client, Prompt prompt) /// The server will process the request and return a result containing messages or other content. /// /// - /// This is a convenience method that internally calls + /// This is a convenience method that internally calls /// with this prompt's name and arguments. /// /// diff --git a/src/ModelContextProtocol.Core/Client/McpClientResource.cs b/src/ModelContextProtocol.Core/Client/McpClientResource.cs index 06f8aff6..19f11bfd 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResource.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResource.cs @@ -9,15 +9,15 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a resource defined on an MCP server. It allows /// retrieving the resource's content by sending a request to the server with the resource's URI. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// public sealed class McpClientResource { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientResource(IMcpClient client, Resource resource) + internal McpClientResource(McpClient client, Resource resource) { _client = client; ProtocolResource = resource; @@ -58,7 +58,7 @@ internal McpClientResource(IMcpClient client, Resource resource) /// A containing the resource's result with content and messages. /// /// - /// This is a convenience method that internally calls . + /// This is a convenience method that internally calls . /// /// public ValueTask ReadAsync( diff --git a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs index 4da1bd0c..033f7cf0 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs @@ -9,15 +9,15 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a resource template defined on an MCP server. It allows /// retrieving the resource template's content by sending a request to the server with the resource's URI. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// public sealed class McpClientResourceTemplate { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientResourceTemplate(IMcpClient client, ResourceTemplate resourceTemplate) + internal McpClientResourceTemplate(McpClient client, ResourceTemplate resourceTemplate) { _client = client; ProtocolResourceTemplate = resourceTemplate; diff --git a/src/ModelContextProtocol.Core/Client/McpClientTool.cs b/src/ModelContextProtocol.Core/Client/McpClientTool.cs index 1810e9c5..c7af513e 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientTool.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientTool.cs @@ -6,11 +6,11 @@ namespace ModelContextProtocol.Client; /// -/// Provides an that calls a tool via an . +/// Provides an that calls a tool via an . /// /// /// -/// The class encapsulates an along with a description of +/// The class encapsulates an along with a description of /// a tool available via that client, allowing it to be invoked as an . This enables integration /// with AI models that support function calling capabilities. /// @@ -19,8 +19,8 @@ namespace ModelContextProtocol.Client; /// and without changing the underlying tool functionality. /// /// -/// Typically, you would get instances of this class by calling the -/// or extension methods on an instance. +/// Typically, you would get instances of this class by calling the +/// or extension methods on an instance. /// /// public sealed class McpClientTool : AIFunction @@ -32,13 +32,13 @@ public sealed class McpClientTool : AIFunction ["Strict"] = false, // some MCP schemas may not meet "strict" requirements }); - private readonly IMcpClient _client; + private readonly McpClient _client; private readonly string _name; private readonly string _description; private readonly IProgress? _progress; internal McpClientTool( - IMcpClient client, + McpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index c4014ed7..823e266f 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -44,7 +44,7 @@ public StreamableHttpClientSessionTransport( _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; // We connect with the initialization request with the MCP transport. This means that any errors won't be observed - // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClientFactory.ConnectAsync + // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. SetConnected(); } diff --git a/src/ModelContextProtocol.Core/IMcpEndpoint.cs b/src/ModelContextProtocol.Core/IMcpEndpoint.cs index ea825e68..01221ecd 100644 --- a/src/ModelContextProtocol.Core/IMcpEndpoint.cs +++ b/src/ModelContextProtocol.Core/IMcpEndpoint.cs @@ -1,4 +1,4 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; @@ -26,6 +26,7 @@ namespace ModelContextProtocol; /// All MCP endpoints should be properly disposed after use as they implement . /// /// +[Obsolete($"Use {nameof(McpSession)} instead.")] public interface IMcpEndpoint : IAsyncDisposable { /// Gets an identifier associated with the current MCP session. diff --git a/src/ModelContextProtocol.Core/McpEndpoint.cs b/src/ModelContextProtocol.Core/McpEndpoint.cs deleted file mode 100644 index 0d0ccbb9..00000000 --- a/src/ModelContextProtocol.Core/McpEndpoint.cs +++ /dev/null @@ -1,144 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; - -namespace ModelContextProtocol; - -/// -/// Base class for an MCP JSON-RPC endpoint. This covers both MCP clients and servers. -/// It is not supported, nor necessary, to implement both client and server functionality in the same class. -/// If an application needs to act as both a client and a server, it should use separate objects for each. -/// This is especially true as a client represents a connection to one and only one server, and vice versa. -/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction. -/// -internal abstract partial class McpEndpoint : IAsyncDisposable -{ - /// Cached naming information used for name/version when none is specified. - internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); - - private McpSession? _session; - private CancellationTokenSource? _sessionCts; - - private readonly SemaphoreSlim _disposeLock = new(1, 1); - private bool _disposed; - - protected readonly ILogger _logger; - - /// - /// Initializes a new instance of the class. - /// - /// The logger factory. - protected McpEndpoint(ILoggerFactory? loggerFactory = null) - { - _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; - } - - protected RequestHandlers RequestHandlers { get; } = []; - - protected NotificationHandlers NotificationHandlers { get; } = new(); - - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); - - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); - - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => - GetSessionOrThrow().RegisterNotificationHandler(method, handler); - - /// - /// Gets the name of the endpoint for logging and debug purposes. - /// - public abstract string EndpointName { get; } - - /// - /// Task that processes incoming messages from the transport. - /// - protected Task? MessageProcessingTask { get; private set; } - - protected void InitializeSession(ITransport sessionTransport) - { - _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); - } - - [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken) - { - _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); - MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); - } - - protected void CancelSession() => _sessionCts?.Cancel(); - - public async ValueTask DisposeAsync() - { - using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); - - if (_disposed) - { - return; - } - _disposed = true; - - await DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - - /// - /// Cleans up the endpoint and releases resources. - /// - /// - public virtual async ValueTask DisposeUnsynchronizedAsync() - { - LogEndpointShuttingDown(EndpointName); - - try - { - if (_sessionCts is not null) - { - await _sessionCts.CancelAsync().ConfigureAwait(false); - } - - if (MessageProcessingTask is not null) - { - try - { - await MessageProcessingTask.ConfigureAwait(false); - } - catch (OperationCanceledException) - { - // Ignore cancellation - } - } - } - finally - { - _session?.Dispose(); - _sessionCts?.Dispose(); - } - - LogEndpointShutDown(EndpointName); - } - - protected McpSession GetSessionOrThrow() - { -#if NET - ObjectDisposedException.ThrowIf(_disposed, this); -#else - if (_disposed) - { - throw new ObjectDisposedException(GetType().Name); - } -#endif - - return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")] - private partial void LogEndpointShuttingDown(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shut down.")] - private partial void LogEndpointShutDown(string endpointName); -} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs index 4e4abe5c..0f0239fa 100644 --- a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs @@ -1,8 +1,9 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; -using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol; @@ -34,6 +35,7 @@ public static class McpEndpointExtensions /// The options governing request serialization. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. The task result contains the deserialized result. + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendRequestAsync)} instead.")] public static ValueTask SendRequestAsync( this IMcpEndpoint endpoint, string method, @@ -42,53 +44,7 @@ public static ValueTask SendRequestAsync( RequestId requestId = default, CancellationToken cancellationToken = default) where TResult : notnull - { - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); - JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); - return SendRequestAsync(endpoint, method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); - } - - /// - /// Sends a JSON-RPC request and attempts to deserialize the result to . - /// - /// The type of the request parameters to serialize from. - /// The type of the result to deserialize to. - /// The MCP client or server instance. - /// The JSON-RPC method name to invoke. - /// Object representing the request parameters. - /// The type information for request parameter serialization. - /// The type information for request parameter deserialization. - /// The request id for the request. - /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains the deserialized result. - internal static async ValueTask SendRequestAsync( - this IMcpEndpoint endpoint, - string method, - TParameters parameters, - JsonTypeInfo parametersTypeInfo, - JsonTypeInfo resultTypeInfo, - RequestId requestId = default, - CancellationToken cancellationToken = default) - where TResult : notnull - { - Throw.IfNull(endpoint); - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(parametersTypeInfo); - Throw.IfNull(resultTypeInfo); - - JsonRpcRequest jsonRpcRequest = new() - { - Id = requestId, - Method = method, - Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), - }; - - JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); - return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); - } + => AsSessionOrThrow(endpoint).SendRequestAsync(method, parameters, serializerOptions, requestId, cancellationToken); /// /// Sends a parameterless notification to the connected endpoint. @@ -104,12 +60,9 @@ internal static async ValueTask SendRequestAsync( /// changes in state. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendNotificationAsync)} instead.")] public static Task SendNotificationAsync(this IMcpEndpoint client, string method, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(method); - return client.SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); - } + => AsSessionOrThrow(client).SendNotificationAsync(method, cancellationToken); /// /// Sends a notification with parameters to the connected endpoint. @@ -135,42 +88,14 @@ public static Task SendNotificationAsync(this IMcpEndpoint client, string method /// but custom methods can also be used for application-specific notifications. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendNotificationAsync)} instead.")] public static Task SendNotificationAsync( this IMcpEndpoint endpoint, string method, TParameters parameters, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); - return SendNotificationAsync(endpoint, method, parameters, parametersTypeInfo, cancellationToken); - } - - /// - /// Sends a notification to the server with parameters. - /// - /// The MCP client or server instance. - /// The JSON-RPC method name to invoke. - /// Object representing the request parameters. - /// The type information for request parameter serialization. - /// The to monitor for cancellation requests. The default is . - internal static Task SendNotificationAsync( - this IMcpEndpoint endpoint, - string method, - TParameters parameters, - JsonTypeInfo parametersTypeInfo, - CancellationToken cancellationToken = default) - { - Throw.IfNull(endpoint); - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(parametersTypeInfo); - - JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); - return endpoint.SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); - } + => AsSessionOrThrow(endpoint).SendNotificationAsync(method, parameters, serializerOptions, cancellationToken); /// /// Notifies the connected endpoint of progress for a long-running operation. @@ -191,22 +116,33 @@ internal static Task SendNotificationAsync( /// Progress notifications are sent asynchronously and don't block the operation from continuing. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.NotifyProgressAsync)} instead.")] public static Task NotifyProgressAsync( this IMcpEndpoint endpoint, ProgressToken progressToken, - ProgressNotificationValue progress, + ProgressNotificationValue progress, CancellationToken cancellationToken = default) + => AsSessionOrThrow(endpoint).NotifyProgressAsync(progressToken, progress, cancellationToken); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpSession AsSessionOrThrow(IMcpEndpoint endpoint, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - Throw.IfNull(endpoint); + if (endpoint is not McpSession session) + { + ThrowInvalidEndpointType(memberName); + } + + return session; - return endpoint.SendNotificationAsync( - NotificationMethods.ProgressNotification, - new ProgressNotificationParams - { - ProgressToken = progressToken, - Progress = progress, - }, - McpJsonUtilities.JsonContext.Default.ProgressNotificationParams, - cancellationToken); + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpSession)}' are supported. " + + $"Prefer using '{nameof(McpServer)}.{memberName}' instead, as " + + $"'{nameof(McpEndpointExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index da954205..8fea9f2d 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -1,794 +1,261 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.Collections.Concurrent; -using System.Diagnostics; -using System.Diagnostics.Metrics; using System.Text.Json; using System.Text.Json.Nodes; -#if !NET -using System.Threading.Channels; -#endif +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol; /// -/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// Represents a client or server Model Context Protocol (MCP) session. /// -internal sealed partial class McpSession : IDisposable +/// +/// +/// The MCP session provides the core communication functionality used by both clients and servers: +/// +/// Sending JSON-RPC requests and receiving responses. +/// Sending notifications to the connected session. +/// Registering handlers for receiving notifications. +/// +/// +/// +/// serves as the base interface for both and +/// interfaces, providing the common functionality needed for MCP protocol +/// communication. Most applications will use these more specific interfaces rather than working with +/// directly. +/// +/// +/// All MCP sessions should be properly disposed after use as they implement . +/// +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract class McpSession : IMcpEndpoint, IAsyncDisposable +#pragma warning restore CS0618 // Type or member is obsolete { - private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( - "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); - private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( - "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); - private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( - "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); - private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( - "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); + /// Gets an identifier associated with the current MCP session. + /// + /// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE. + /// Can return if the session hasn't initialized or if the transport doesn't + /// support multiple sessions (as is the case with STDIO). + /// + public abstract string? SessionId { get; } - /// The latest version of the protocol supported by this implementation. - internal const string LatestProtocolVersion = "2025-06-18"; - - /// All protocol versions supported by this implementation. - internal static readonly string[] SupportedProtocolVersions = - [ - "2024-11-05", - "2025-03-26", - LatestProtocolVersion, - ]; - - private readonly bool _isServer; - private readonly string _transportKind; - private readonly ITransport _transport; - private readonly RequestHandlers _requestHandlers; - private readonly NotificationHandlers _notificationHandlers; - private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); - - private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; - - /// Collection of requests sent on this session and waiting for responses. - private readonly ConcurrentDictionary> _pendingRequests = []; /// - /// Collection of requests received on this session and currently being handled. The value provides a - /// that can be used to request cancellation of the in-flight handler. + /// Sends a JSON-RPC request to the connected session and waits for a response. /// - private readonly ConcurrentDictionary _handlingRequests = new(); - private readonly ILogger _logger; - - // This _sessionId is solely used to identify the session in telemetry and logs. - private readonly string _sessionId = Guid.NewGuid().ToString("N"); - private long _lastRequestId; - - /// - /// Initializes a new instance of the class. - /// - /// true if this is a server; false if it's a client. - /// An MCP transport implementation. - /// The name of the endpoint for logging and debug purposes. - /// A collection of request handlers. - /// A collection of notification handlers. - /// The logger. - public McpSession( - bool isServer, - ITransport transport, - string endpointName, - RequestHandlers requestHandlers, - NotificationHandlers notificationHandlers, - ILogger logger) - { - Throw.IfNull(transport); - - _transportKind = transport switch - { - StdioClientSessionTransport or StdioServerTransport => "stdio", - StreamClientSessionTransport or StreamServerTransport => "stream", - SseClientSessionTransport or SseResponseStreamTransport => "sse", - StreamableHttpClientSessionTransport or StreamableHttpServerTransport or StreamableHttpPostTransport => "http", - _ => "unknownTransport" - }; - - _isServer = isServer; - _transport = transport; - EndpointName = endpointName; - _requestHandlers = requestHandlers; - _notificationHandlers = notificationHandlers; - _logger = logger ?? NullLogger.Instance; - LogSessionCreated(EndpointName, _sessionId, _transportKind); - } + /// The JSON-RPC request to send. + /// The to monitor for cancellation requests. The default is . + /// A task containing the session's response. + /// The transport is not connected, or another error occurs during request processing. + /// An error occured during request processing. + /// + /// This method provides low-level access to send raw JSON-RPC requests. For most use cases, + /// consider using the strongly-typed methods that provide a more convenient API. + /// + public abstract Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default); /// - /// Gets and sets the name of the endpoint for logging and debug purposes. + /// Sends a JSON-RPC message to the connected session. /// - public string EndpointName { get; set; } + /// + /// The JSON-RPC message to send. This can be any type that implements JsonRpcMessage, such as + /// JsonRpcRequest, JsonRpcResponse, JsonRpcNotification, or JsonRpcError. + /// + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// The transport is not connected. + /// is . + /// + /// + /// This method provides low-level access to send any JSON-RPC message. For specific message types, + /// consider using the higher-level methods such as or methods + /// on this class that provide a simpler API. + /// + /// + /// The method will serialize the message and transmit it using the underlying transport mechanism. + /// + /// + public abstract Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); + + /// Registers a handler to be invoked when a notification for the specified method is received. + /// The notification method. + /// The handler to be invoked. + /// An that will remove the registered handler when disposed. + public abstract IAsyncDisposable RegisterNotificationHandler(string method, Func handler); + + /// + public abstract ValueTask DisposeAsync(); /// - /// Starts processing messages from the transport. This method will block until the transport is disconnected. - /// This is generally started in a background task or thread from the initialization logic of the derived class. + /// Sends a JSON-RPC request and attempts to deserialize the result to . /// - public async Task ProcessMessagesAsync(CancellationToken cancellationToken) - { - try - { - await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) - { - LogMessageRead(EndpointName, message.GetType().Name); - - // Fire and forget the message handling to avoid blocking the transport. - if (message.ExecutionContext is null) - { - _ = ProcessMessageAsync(); - } - else - { - // Flow the execution context from the HTTP request corresponding to this message if provided. - ExecutionContext.Run(message.ExecutionContext, _ => _ = ProcessMessageAsync(), null); - } - - async Task ProcessMessageAsync() - { - JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; - CancellationTokenSource? combinedCts = null; - try - { - // Register before we yield, so that the tracking is guaranteed to be there - // when subsequent messages arrive, even if the asynchronous processing happens - // out of order. - if (messageWithId is not null) - { - combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _handlingRequests[messageWithId.Id] = combinedCts; - } - - // If we await the handler without yielding first, the transport may not be able to read more messages, - // which could lead to a deadlock if the handler sends a message back. -#if NET - await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); -#else - await default(ForceYielding); -#endif - - // Handle the message. - await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - // Only send responses for request errors that aren't user-initiated cancellation. - bool isUserCancellation = - ex is OperationCanceledException && - !cancellationToken.IsCancellationRequested && - combinedCts?.IsCancellationRequested is true; - - if (!isUserCancellation && message is JsonRpcRequest request) - { - LogRequestHandlerException(EndpointName, request.Method, ex); - - JsonRpcErrorDetail detail = ex is McpException mcpe ? - new() - { - Code = (int)mcpe.ErrorCode, - Message = mcpe.Message, - } : - new() - { - Code = (int)McpErrorCode.InternalError, - Message = "An error occurred.", - }; - - await SendMessageAsync(new JsonRpcError - { - Id = request.Id, - JsonRpc = "2.0", - Error = detail, - RelatedTransport = request.RelatedTransport, - }, cancellationToken).ConfigureAwait(false); - } - else if (ex is not OperationCanceledException) - { - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); - } - else - { - LogMessageHandlerException(EndpointName, message.GetType().Name, ex); - } - } - } - finally - { - if (messageWithId is not null) - { - _handlingRequests.TryRemove(messageWithId.Id, out _); - combinedCts!.Dispose(); - } - } - } - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - // Normal shutdown - LogEndpointMessageProcessingCanceled(EndpointName); - } - finally - { - // Fail any pending requests, as they'll never be satisfied. - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); - } - } - } - - private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) - { - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = GetMethodName(message); - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - - Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? - Diagnostics.ActivitySource.StartActivity( - CreateActivityName(method), - ActivityKind.Server, - parentContext: _propagator.ExtractActivityContext(message), - links: Diagnostics.ActivityLinkFromCurrent()) : - null; - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - try - { - if (addTags) - { - AddTags(ref tags, activity, message, method); - } - - switch (message) - { - case JsonRpcRequest request: - var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); - AddResponseTags(ref tags, activity, result, method); - break; - - case JsonRpcNotification notification: - await HandleNotification(notification, cancellationToken).ConfigureAwait(false); - break; - - case JsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; - - default: - LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; - } - } - catch (Exception e) when (addTags) - { - AddExceptionTags(ref tags, activity, e); - throw; - } - finally - { - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The request id for the request. + /// The options governing request serialization. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + public ValueTask SendRequestAsync( + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + RequestId requestId = default, + CancellationToken cancellationToken = default) + where TResult : notnull { - // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) - if (notification.Method == NotificationMethods.CancelledNotification) - { - try - { - if (GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _handlingRequests.TryGetValue(cn.RequestId, out var cts)) - { - await cts.CancelAsync().ConfigureAwait(false); - LogRequestCanceled(EndpointName, cn.RequestId, cn.Reason); - } - } - catch - { - // "Invalid cancellation notifications SHOULD be ignored" - } - } + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); - // Handle user-defined notifications. - await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); + JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); + JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); + return SendRequestAsync(method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); } - private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) - { - if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) - { - tcs.TrySetResult(message); - } - else - { - LogNoRequestFoundForMessageWithId(EndpointName, messageWithId.Id); - } - } - - private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The type information for request parameter deserialization. + /// The request id for the request. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + internal async ValueTask SendRequestAsync( + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + JsonTypeInfo resultTypeInfo, + RequestId requestId = default, + CancellationToken cancellationToken = default) + where TResult : notnull { - if (!_requestHandlers.TryGetValue(request.Method, out var handler)) - { - LogNoHandlerFoundForRequest(EndpointName, request.Method); - throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); - } - - LogRequestHandlerCalled(EndpointName, request.Method); - JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); - LogRequestHandlerCompleted(EndpointName, request.Method); - - await SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = result, - RelatedTransport = request.RelatedTransport, - }, cancellationToken).ConfigureAwait(false); - - return result; - } + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + Throw.IfNull(resultTypeInfo); - private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) - { - if (!cancellationToken.CanBeCanceled) + JsonRpcRequest jsonRpcRequest = new() { - return default; - } + Id = requestId, + Method = method, + Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), + }; - return cancellationToken.Register(static objState => - { - var state = (Tuple)objState!; - _ = state.Item1.SendMessageAsync(new JsonRpcNotification - { - Method = NotificationMethods.CancelledNotification, - Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), - RelatedTransport = state.Item2.RelatedTransport, - }); - }, Tuple.Create(this, request)); + JsonRpcResponse response = await SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); + return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); } - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + /// + /// Sends a parameterless notification to the connected session. + /// + /// The notification method name. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// + /// + /// This method sends a notification without any parameters. Notifications are one-way messages + /// that don't expect a response. They are commonly used for events, status updates, or to signal + /// changes in state. + /// + /// + public Task SendNotificationAsync(string method, CancellationToken cancellationToken = default) { Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(handler); - - return _notificationHandlers.Register(method, handler); + return SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); } /// - /// Sends a JSON-RPC request to the server. - /// It is strongly recommended use the capability-specific methods instead of this one. - /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// Sends a notification with parameters to the connected session. /// - /// The JSON-RPC request to send. + /// The type of the notification parameters to serialize. + /// The JSON-RPC method name for the notification. + /// Object representing the notification parameters. + /// The options governing parameter serialization. If null, default options are used. /// The to monitor for cancellation requests. The default is . - /// A task containing the server's response. - public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = request.Method; - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : - null; - - // Set request ID - if (request.Id.Id is null) - { - request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); - } - - _propagator.InjectActivityContext(activity, request); - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _pendingRequests[request.Id] = tcs; - try - { - if (addTags) - { - AddTags(ref tags, activity, request, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingRequest(EndpointName, request.Method); - } - - await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); - - // Now that the request has been sent, register for cancellation. If we registered before, - // a cancellation request could arrive before the server knew about that request ID, in which - // case the server could ignore it. - LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); - JsonRpcMessage? response; - using (var registration = RegisterCancellation(cancellationToken, request)) - { - response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); - } - - if (response is JsonRpcError error) - { - LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); - throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); - } - - if (response is JsonRpcResponse success) - { - if (addTags) - { - AddResponseTags(ref tags, activity, success.Result, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); - } - else - { - LogRequestResponseReceived(EndpointName, request.Method); - } - - return success; - } - - // Unexpected response type - LogSendingRequestInvalidResponseType(EndpointName, request.Method); - throw new McpException("Invalid response type"); - } - catch (Exception ex) when (addTags) - { - AddExceptionTags(ref tags, activity, ex); - throw; - } - finally - { - _pendingRequests.TryRemove(request.Id, out _); - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + /// A task that represents the asynchronous send operation. + /// + /// + /// This method sends a notification with parameters to the connected session. Notifications are one-way + /// messages that don't expect a response, commonly used for events, status updates, or signaling changes. + /// + /// + /// The parameters object is serialized to JSON according to the provided serializer options or the default + /// options if none are specified. + /// + /// + /// The Model Context Protocol defines several standard notification methods in , + /// but custom methods can also be used for application-specific notifications. + /// + /// + public Task SendNotificationAsync( + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) { - Throw.IfNull(message); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); - cancellationToken.ThrowIfCancellationRequested(); - - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = GetMethodName(message); - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : - null; - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - - // propagate trace context - _propagator?.InjectActivityContext(activity, message); - - try - { - if (addTags) - { - AddTags(ref tags, activity, message, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingMessage(EndpointName); - } - - await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); - - // If the sent notification was a cancellation notification, cancel the pending request's await, as either the - // server won't be sending a response, or per the specification, the response should be ignored. There are inherent - // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. - if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && - GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _pendingRequests.TryRemove(cn.RequestId, out var tcs)) - { - tcs.TrySetCanceled(default); - } - } - catch (Exception ex) when (addTags) - { - AddExceptionTags(ref tags, activity, ex); - throw; - } - finally - { - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the - // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in - // the HTTP response body for the POST request containing the corresponding JSON-RPC request. - private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) - => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); - - private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) - { - try - { - return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams); - } - catch - { - return null; - } - } - - private string CreateActivityName(string method) => method; - - private static string GetMethodName(JsonRpcMessage message) => - message switch - { - JsonRpcRequest request => request.Method, - JsonRpcNotification notification => notification.Method, - _ => "unknownMethod" - }; - - private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) - { - tags.Add("mcp.method.name", method); - tags.Add("network.transport", _transportKind); - - // TODO: When using SSE transport, add: - // - server.address and server.port on client spans and metrics - // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport - if (activity is { IsAllDataRequested: true }) - { - // session and request id have high cardinality, so not applying to metric tags - activity.AddTag("mcp.session.id", _sessionId); - - if (message is JsonRpcMessageWithId withId) - { - activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); - } - } - - JsonObject? paramsObj = message switch - { - JsonRpcRequest request => request.Params as JsonObject, - JsonRpcNotification notification => notification.Params as JsonObject, - _ => null - }; - - if (paramsObj == null) - { - return; - } - - string? target = null; - switch (method) - { - case RequestMethods.ToolsCall: - case RequestMethods.PromptsGet: - target = GetStringProperty(paramsObj, "name"); - if (target is not null) - { - tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); - } - break; - - case RequestMethods.ResourcesRead: - case RequestMethods.ResourcesSubscribe: - case RequestMethods.ResourcesUnsubscribe: - case NotificationMethods.ResourceUpdatedNotification: - target = GetStringProperty(paramsObj, "uri"); - if (target is not null) - { - tags.Add("mcp.resource.uri", target); - } - break; - } - - if (activity is { IsAllDataRequested: true }) - { - activity.DisplayName = target == null ? method : $"{method} {target}"; - } - } - - private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) - { - if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) - { - e = ae.InnerException; - } - - int? intErrorCode = - (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : - e is JsonException ? (int)McpErrorCode.ParseError : - null; - - string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; - tags.Add("error.type", errorType); - if (intErrorCode is not null) - { - tags.Add("rpc.jsonrpc.error_code", errorType); - } - - if (activity is { IsAllDataRequested: true }) - { - activity.SetStatus(ActivityStatusCode.Error, e.Message); - } + JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); + return SendNotificationAsync(method, parameters, parametersTypeInfo, cancellationToken); } - private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) + /// + /// Sends a notification to the server with parameters. + /// + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The to monitor for cancellation requests. The default is . + internal Task SendNotificationAsync( + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + CancellationToken cancellationToken = default) { - if (response is JsonObject jsonObject - && jsonObject.TryGetPropertyValue("isError", out var isError) - && isError?.GetValueKind() == JsonValueKind.True) - { - if (activity is { IsAllDataRequested: true }) - { - string? content = null; - if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) - { - content = prop.ToJsonString(); - } - - activity.SetStatus(ActivityStatusCode.Error, content); - } + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); - tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); - } + JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); + return SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); } - private static void FinalizeDiagnostics( - Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) + /// + /// Notifies the connected session of progress for a long-running operation. + /// + /// The identifying the operation for which progress is being reported. + /// The progress update to send, containing information such as percentage complete or status message. + /// The to monitor for cancellation requests. The default is . + /// A task representing the completion of the notification operation (not the operation being tracked). + /// The current session instance is . + /// + /// + /// This method sends a progress notification to the connected session using the Model Context Protocol's + /// standardized progress notification format. Progress updates are identified by a + /// that allows the recipient to correlate multiple updates with a specific long-running operation. + /// + /// + /// Progress notifications are sent asynchronously and don't block the operation from continuing. + /// + /// + public Task NotifyProgressAsync( + ProgressToken progressToken, + ProgressNotificationValue progress, + CancellationToken cancellationToken = default) { - try - { - if (startingTimestamp is not null) - { - durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); - } - - if (activity is { IsAllDataRequested: true }) + return SendNotificationAsync( + NotificationMethods.ProgressNotification, + new ProgressNotificationParams { - foreach (var tag in tags) - { - activity.AddTag(tag.Key, tag.Value); - } - } - } - finally - { - activity?.Dispose(); - } + ProgressToken = progressToken, + Progress = progress, + }, + McpJsonUtilities.JsonContext.Default.ProgressNotificationParams, + cancellationToken); } - - public void Dispose() - { - Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; - if (durationMetric.Enabled) - { - TagList tags = default; - tags.Add("network.transport", _transportKind); - - // TODO: Add server.address and server.port on client-side when using SSE transport, - // client.* attributes are not added to metrics because of cardinality - durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); - } - - // Complete all pending requests with cancellation - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetCanceled(); - } - - _pendingRequests.Clear(); - LogSessionDisposed(EndpointName, _sessionId, _transportKind); - } - -#if !NET - private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; -#endif - - private static TimeSpan GetElapsed(long startingTimestamp) => -#if NET - Stopwatch.GetElapsedTime(startingTimestamp); -#else - new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); -#endif - - private static string? GetStringProperty(JsonObject parameters, string propName) - { - if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) - { - return prop.GetValue(); - } - - return null; - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] - private partial void LogEndpointMessageProcessingCanceled(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler called.")] - private partial void LogRequestHandlerCalled(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler completed.")] - private partial void LogRequestHandlerCompleted(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] - private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] - private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] - private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] - private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending method '{Method}' request.")] - private partial void LogSendingRequest(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending method '{Method}' request. Request: '{Request}'.")] - private partial void LogSendingRequestSensitive(string endpointName, string method, string request); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] - private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] - private partial void LogRequestResponseReceived(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] - private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] - private partial void LogMessageRead(string endpointName, string messageType); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} message handler {MessageType} failed.")] - private partial void LogMessageHandlerException(string endpointName, string messageType, Exception exception); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} message handler {MessageType} failed. Message: '{Message}'.")] - private partial void LogMessageHandlerExceptionSensitive(string endpointName, string messageType, string message, Exception exception); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received unexpected {MessageType} message type.")] - private partial void LogEndpointHandlerUnexpectedMessageType(string endpointName, string messageType); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] - private partial void LogNoHandlerFoundForRequest(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] - private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] - private partial void LogSendingMessage(string endpointName); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")] - private partial void LogSendingMessageSensitive(string endpointName, string message); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} created with transport {TransportKind}")] - private partial void LogSessionCreated(string endpointName, string sessionId, string transportKind); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} disposed with transport {TransportKind}")] - private partial void LogSessionDisposed(string endpointName, string sessionId, string transportKind); } diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs new file mode 100644 index 00000000..463e796a --- /dev/null +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -0,0 +1,830 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Text.Json; +using System.Text.Json.Nodes; +#if !NET +using System.Threading.Channels; +#endif + +namespace ModelContextProtocol; + +/// +/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// +internal sealed partial class McpSessionHandler : IAsyncDisposable +{ + private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); + private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); + private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); + private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); + + /// The latest version of the protocol supported by this implementation. + internal const string LatestProtocolVersion = "2025-06-18"; + + /// All protocol versions supported by this implementation. + internal static readonly string[] SupportedProtocolVersions = + [ + "2024-11-05", + "2025-03-26", + LatestProtocolVersion, + ]; + + private readonly bool _isServer; + private readonly string _transportKind; + private readonly ITransport _transport; + private readonly RequestHandlers _requestHandlers; + private readonly NotificationHandlers _notificationHandlers; + private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); + + private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; + + /// Collection of requests sent on this session and waiting for responses. + private readonly ConcurrentDictionary> _pendingRequests = []; + /// + /// Collection of requests received on this session and currently being handled. The value provides a + /// that can be used to request cancellation of the in-flight handler. + /// + private readonly ConcurrentDictionary _handlingRequests = new(); + private readonly ILogger _logger; + + // This _sessionId is solely used to identify the session in telemetry and logs. + private readonly string _sessionId = Guid.NewGuid().ToString("N"); + private long _lastRequestId; + + private CancellationTokenSource? _messageProcessingCts; + private Task? _messageProcessingTask; + + /// + /// Initializes a new instance of the class. + /// + /// true if this is a server; false if it's a client. + /// An MCP transport implementation. + /// The name of the endpoint for logging and debug purposes. + /// A collection of request handlers. + /// A collection of notification handlers. + /// The logger. + public McpSessionHandler( + bool isServer, + ITransport transport, + string endpointName, + RequestHandlers requestHandlers, + NotificationHandlers notificationHandlers, + ILogger logger) + { + Throw.IfNull(transport); + + _transportKind = transport switch + { + StdioClientSessionTransport or StdioServerTransport => "stdio", + StreamClientSessionTransport or StreamServerTransport => "stream", + SseClientSessionTransport or SseResponseStreamTransport => "sse", + StreamableHttpClientSessionTransport or StreamableHttpServerTransport or StreamableHttpPostTransport => "http", + _ => "unknownTransport" + }; + + _isServer = isServer; + _transport = transport; + EndpointName = endpointName; + _requestHandlers = requestHandlers; + _notificationHandlers = notificationHandlers; + _logger = logger ?? NullLogger.Instance; + LogSessionCreated(EndpointName, _sessionId, _transportKind); + } + + /// + /// Gets and sets the name of the endpoint for logging and debug purposes. + /// + public string EndpointName { get; set; } + + /// + /// Starts processing messages from the transport. This method will block until the transport is disconnected. + /// This is generally started in a background task or thread from the initialization logic of the derived class. + /// + public Task ProcessMessagesAsync(CancellationToken cancellationToken) + { + if (_messageProcessingTask is not null) + { + throw new InvalidOperationException("The message processing loop has already started."); + } + + Debug.Assert(_messageProcessingCts is null); + + _messageProcessingCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _messageProcessingTask = ProcessMessagesCoreAsync(_messageProcessingCts.Token); + return _messageProcessingTask; + } + + private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken) + { + try + { + await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + LogMessageRead(EndpointName, message.GetType().Name); + + // Fire and forget the message handling to avoid blocking the transport. + if (message.ExecutionContext is null) + { + _ = ProcessMessageAsync(); + } + else + { + // Flow the execution context from the HTTP request corresponding to this message if provided. + ExecutionContext.Run(message.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + } + + async Task ProcessMessageAsync() + { + JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; + CancellationTokenSource? combinedCts = null; + try + { + // Register before we yield, so that the tracking is guaranteed to be there + // when subsequent messages arrive, even if the asynchronous processing happens + // out of order. + if (messageWithId is not null) + { + combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _handlingRequests[messageWithId.Id] = combinedCts; + } + + // If we await the handler without yielding first, the transport may not be able to read more messages, + // which could lead to a deadlock if the handler sends a message back. +#if NET + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); +#else + await default(ForceYielding); +#endif + + // Handle the message. + await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + // Only send responses for request errors that aren't user-initiated cancellation. + bool isUserCancellation = + ex is OperationCanceledException && + !cancellationToken.IsCancellationRequested && + combinedCts?.IsCancellationRequested is true; + + if (!isUserCancellation && message is JsonRpcRequest request) + { + LogRequestHandlerException(EndpointName, request.Method, ex); + + JsonRpcErrorDetail detail = ex is McpException mcpe ? + new() + { + Code = (int)mcpe.ErrorCode, + Message = mcpe.Message, + } : + new() + { + Code = (int)McpErrorCode.InternalError, + Message = "An error occurred.", + }; + + await SendMessageAsync(new JsonRpcError + { + Id = request.Id, + JsonRpc = "2.0", + Error = detail, + RelatedTransport = request.RelatedTransport, + }, cancellationToken).ConfigureAwait(false); + } + else if (ex is not OperationCanceledException) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); + } + else + { + LogMessageHandlerException(EndpointName, message.GetType().Name, ex); + } + } + } + finally + { + if (messageWithId is not null) + { + _handlingRequests.TryRemove(messageWithId.Id, out _); + combinedCts!.Dispose(); + } + } + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + LogEndpointMessageProcessingCanceled(EndpointName); + } + finally + { + // Fail any pending requests, as they'll never be satisfied. + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); + } + } + } + + private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + + Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity( + CreateActivityName(method), + ActivityKind.Server, + parentContext: _propagator.ExtractActivityContext(message), + links: Diagnostics.ActivityLinkFromCurrent()) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + try + { + if (addTags) + { + AddTags(ref tags, activity, message, method); + } + + switch (message) + { + case JsonRpcRequest request: + var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); + AddResponseTags(ref tags, activity, result, method); + break; + + case JsonRpcNotification notification: + await HandleNotification(notification, cancellationToken).ConfigureAwait(false); + break; + + case JsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + break; + + default: + LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + break; + } + } + catch (Exception e) when (addTags) + { + AddExceptionTags(ref tags, activity, e); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) + { + // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) + if (notification.Method == NotificationMethods.CancelledNotification) + { + try + { + if (GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _handlingRequests.TryGetValue(cn.RequestId, out var cts)) + { + await cts.CancelAsync().ConfigureAwait(false); + LogRequestCanceled(EndpointName, cn.RequestId, cn.Reason); + } + } + catch + { + // "Invalid cancellation notifications SHOULD be ignored" + } + } + + // Handle user-defined notifications. + await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); + } + + private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) + { + if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) + { + tcs.TrySetResult(message); + } + else + { + LogNoRequestFoundForMessageWithId(EndpointName, messageWithId.Id); + } + } + + private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + { + if (!_requestHandlers.TryGetValue(request.Method, out var handler)) + { + LogNoHandlerFoundForRequest(EndpointName, request.Method); + throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); + } + + LogRequestHandlerCalled(EndpointName, request.Method); + JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); + LogRequestHandlerCompleted(EndpointName, request.Method); + + await SendMessageAsync(new JsonRpcResponse + { + Id = request.Id, + Result = result, + RelatedTransport = request.RelatedTransport, + }, cancellationToken).ConfigureAwait(false); + + return result; + } + + private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) + { + if (!cancellationToken.CanBeCanceled) + { + return default; + } + + return cancellationToken.Register(static objState => + { + var state = (Tuple)objState!; + _ = state.Item1.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.CancelledNotification, + Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), + RelatedTransport = state.Item2.RelatedTransport, + }); + }, Tuple.Create(this, request)); + } + + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(handler); + + return _notificationHandlers.Register(method, handler); + } + + /// + /// Sends a JSON-RPC request to the server. + /// It is strongly recommended use the capability-specific methods instead of this one. + /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// + /// The JSON-RPC request to send. + /// The to monitor for cancellation requests. The default is . + /// A task containing the server's response. + public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) + { + Throw.IfNull(request); + + cancellationToken.ThrowIfCancellationRequested(); + + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = request.Method; + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : + null; + + // Set request ID + if (request.Id.Id is null) + { + request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); + } + + _propagator.InjectActivityContext(activity, request); + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _pendingRequests[request.Id] = tcs; + try + { + if (addTags) + { + AddTags(ref tags, activity, request, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingRequest(EndpointName, request.Method); + } + + await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); + + // Now that the request has been sent, register for cancellation. If we registered before, + // a cancellation request could arrive before the server knew about that request ID, in which + // case the server could ignore it. + LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); + JsonRpcMessage? response; + using (var registration = RegisterCancellation(cancellationToken, request)) + { + response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + if (response is JsonRpcError error) + { + LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); + throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); + } + + if (response is JsonRpcResponse success) + { + if (addTags) + { + AddResponseTags(ref tags, activity, success.Result, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); + } + else + { + LogRequestResponseReceived(EndpointName, request.Method); + } + + return success; + } + + // Unexpected response type + LogSendingRequestInvalidResponseType(EndpointName, request.Method); + throw new McpException("Invalid response type"); + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, activity, ex); + throw; + } + finally + { + _pendingRequests.TryRemove(request.Id, out _); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + Throw.IfNull(message); + + cancellationToken.ThrowIfCancellationRequested(); + + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + // propagate trace context + _propagator?.InjectActivityContext(activity, message); + + try + { + if (addTags) + { + AddTags(ref tags, activity, message, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingMessage(EndpointName); + } + + await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); + + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, activity, ex); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the + // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in + // the HTTP response body for the POST request containing the corresponding JSON-RPC request. + private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); + + private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) + { + try + { + return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams); + } + catch + { + return null; + } + } + + private string CreateActivityName(string method) => method; + + private static string GetMethodName(JsonRpcMessage message) => + message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => "unknownMethod" + }; + + private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) + { + tags.Add("mcp.method.name", method); + tags.Add("network.transport", _transportKind); + + // TODO: When using SSE transport, add: + // - server.address and server.port on client spans and metrics + // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport + if (activity is { IsAllDataRequested: true }) + { + // session and request id have high cardinality, so not applying to metric tags + activity.AddTag("mcp.session.id", _sessionId); + + if (message is JsonRpcMessageWithId withId) + { + activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); + } + } + + JsonObject? paramsObj = message switch + { + JsonRpcRequest request => request.Params as JsonObject, + JsonRpcNotification notification => notification.Params as JsonObject, + _ => null + }; + + if (paramsObj == null) + { + return; + } + + string? target = null; + switch (method) + { + case RequestMethods.ToolsCall: + case RequestMethods.PromptsGet: + target = GetStringProperty(paramsObj, "name"); + if (target is not null) + { + tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); + } + break; + + case RequestMethods.ResourcesRead: + case RequestMethods.ResourcesSubscribe: + case RequestMethods.ResourcesUnsubscribe: + case NotificationMethods.ResourceUpdatedNotification: + target = GetStringProperty(paramsObj, "uri"); + if (target is not null) + { + tags.Add("mcp.resource.uri", target); + } + break; + } + + if (activity is { IsAllDataRequested: true }) + { + activity.DisplayName = target == null ? method : $"{method} {target}"; + } + } + + private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) + { + if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) + { + e = ae.InnerException; + } + + int? intErrorCode = + (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : + e is JsonException ? (int)McpErrorCode.ParseError : + null; + + string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; + tags.Add("error.type", errorType); + if (intErrorCode is not null) + { + tags.Add("rpc.jsonrpc.error_code", errorType); + } + + if (activity is { IsAllDataRequested: true }) + { + activity.SetStatus(ActivityStatusCode.Error, e.Message); + } + } + + private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) + { + if (response is JsonObject jsonObject + && jsonObject.TryGetPropertyValue("isError", out var isError) + && isError?.GetValueKind() == JsonValueKind.True) + { + if (activity is { IsAllDataRequested: true }) + { + string? content = null; + if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) + { + content = prop.ToJsonString(); + } + + activity.SetStatus(ActivityStatusCode.Error, content); + } + + tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); + } + } + + private static void FinalizeDiagnostics( + Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) + { + try + { + if (startingTimestamp is not null) + { + durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); + } + + if (activity is { IsAllDataRequested: true }) + { + foreach (var tag in tags) + { + activity.AddTag(tag.Key, tag.Value); + } + } + } + finally + { + activity?.Dispose(); + } + } + + public async ValueTask DisposeAsync() + { + Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; + if (durationMetric.Enabled) + { + TagList tags = default; + tags.Add("network.transport", _transportKind); + + // TODO: Add server.address and server.port on client-side when using SSE transport, + // client.* attributes are not added to metrics because of cardinality + durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); + } + + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetCanceled(); + } + + _pendingRequests.Clear(); + + if (_messageProcessingCts is not null) + { + await _messageProcessingCts.CancelAsync().ConfigureAwait(false); + } + + if (_messageProcessingTask is not null) + { + try + { + await _messageProcessingTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Ignore cancellation + } + } + + LogSessionDisposed(EndpointName, _sessionId, _transportKind); + } + +#if !NET + private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; +#endif + + private static TimeSpan GetElapsed(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); +#endif + + private static string? GetStringProperty(JsonObject parameters, string propName) + { + if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) + { + return prop.GetValue(); + } + + return null; + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] + private partial void LogEndpointMessageProcessingCanceled(string endpointName); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler called.")] + private partial void LogRequestHandlerCalled(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler completed.")] + private partial void LogRequestHandlerCompleted(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] + private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] + private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] + private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] + private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending method '{Method}' request.")] + private partial void LogSendingRequest(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending method '{Method}' request. Request: '{Request}'.")] + private partial void LogSendingRequestSensitive(string endpointName, string method, string request); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] + private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] + private partial void LogRequestResponseReceived(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] + private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] + private partial void LogMessageRead(string endpointName, string messageType); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} message handler {MessageType} failed.")] + private partial void LogMessageHandlerException(string endpointName, string messageType, Exception exception); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} message handler {MessageType} failed. Message: '{Message}'.")] + private partial void LogMessageHandlerExceptionSensitive(string endpointName, string messageType, string message, Exception exception); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received unexpected {MessageType} message type.")] + private partial void LogEndpointHandlerUnexpectedMessageType(string endpointName, string messageType); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] + private partial void LogNoHandlerFoundForRequest(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] + private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] + private partial void LogSendingMessage(string endpointName); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")] + private partial void LogSendingMessageSensitive(string endpointName, string message); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} created with transport {TransportKind}")] + private partial void LogSessionCreated(string endpointName, string sessionId, string transportKind); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} disposed with transport {TransportKind}")] + private partial void LogSessionDisposed(string endpointName, string sessionId, string transportKind); +} diff --git a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs index ebe69813..c065ed6c 100644 --- a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs @@ -44,7 +44,7 @@ public sealed class ClientCapabilities /// server requests for listing root URIs. Root URIs serve as entry points for resource navigation in the protocol. /// /// - /// The server can use to request the list of + /// The server can use to request the list of /// available roots from the client, which will trigger the client's . /// /// @@ -78,7 +78,7 @@ public sealed class ClientCapabilities /// /// /// Handlers provided via will be registered with the client for the lifetime of the client. - /// For transient handlers, may be used to register a handler that can + /// For transient handlers, may be used to register a handler that can /// then be unregistered by disposing of the returned from the method. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/ITransport.cs b/src/ModelContextProtocol.Core/Protocol/ITransport.cs index e35b3a6f..148472e9 100644 --- a/src/ModelContextProtocol.Core/Protocol/ITransport.cs +++ b/src/ModelContextProtocol.Core/Protocol/ITransport.cs @@ -62,8 +62,8 @@ public interface ITransport : IAsyncDisposable /// /// /// This is a core method used by higher-level abstractions in the MCP protocol implementation. - /// Most client code should use the higher-level methods provided by , - /// , , or , + /// Most client code should use the higher-level methods provided by , + /// , or , /// rather than accessing this method directly. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index b3176937..65003667 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -44,7 +44,7 @@ private protected JsonRpcMessage() /// /// /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, - /// the outlives the initial HTTP request context it was created on, and new + /// the outlives the initial HTTP request context it was created on, and new /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor /// in tool calls. diff --git a/src/ModelContextProtocol.Core/Protocol/Reference.cs b/src/ModelContextProtocol.Core/Protocol/Reference.cs index a9c87fe4..af95cf33 100644 --- a/src/ModelContextProtocol.Core/Protocol/Reference.cs +++ b/src/ModelContextProtocol.Core/Protocol/Reference.cs @@ -12,7 +12,7 @@ namespace ModelContextProtocol.Protocol; /// /// /// -/// References are commonly used with to request completion suggestions for arguments, +/// References are commonly used with to request completion suggestions for arguments, /// and with other methods that need to reference resources or prompts. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs index 6e0f1190..7828ce29 100644 --- a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs @@ -34,7 +34,7 @@ public sealed class SamplingCapability /// generated content. /// /// - /// You can create a handler using the extension + /// You can create a handler using the extension /// method with any implementation of . /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs index 6a4b2e62..023a869a 100644 --- a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs @@ -77,7 +77,7 @@ public sealed class ServerCapabilities /// /// /// Handlers provided via will be registered with the server for the lifetime of the server. - /// For transient handlers, may be used to register a handler that can + /// For transient handlers, may be used to register a handler that can /// then be unregistered by disposing of the returned from the method. /// /// diff --git a/src/ModelContextProtocol.Core/README.md b/src/ModelContextProtocol.Core/README.md index beb365c8..f6cffaf6 100644 --- a/src/ModelContextProtocol.Core/README.md +++ b/src/ModelContextProtocol.Core/README.md @@ -27,8 +27,8 @@ dotnet add package ModelContextProtocol.Core --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClient.CreateAsync` method is used to instantiate and connect an `McpClient` +to a server. Once you have an `McpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions @@ -38,7 +38,7 @@ var clientTransport = new StdioClientTransport(new StdioClientTransportOptions Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); -var client = await McpClientFactory.CreateAsync(clientTransport); +var client = await McpClient.CreateAsync(clientTransport); // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) diff --git a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs b/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs index 3372072f..df800617 100644 --- a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs +++ b/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs @@ -17,13 +17,13 @@ internal sealed class RequestServiceProvider( /// Gets whether the specified type is in the list of additional types this service provider wraps around the one in a provided request's services. public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(RequestContext) || - serviceType == typeof(IMcpServer) || + serviceType == typeof(McpServer) || serviceType == typeof(IProgress); /// public object? GetService(Type serviceType) => serviceType == typeof(RequestContext) ? request : - serviceType == typeof(IMcpServer) ? request.Server : + serviceType == typeof(McpServer) ? request.Server : serviceType == typeof(IProgress) ? (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : innerServices?.GetService(serviceType); diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index d286d1ef..791a47c8 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -3,31 +3,30 @@ namespace ModelContextProtocol.Server; -internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer +internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport? transport) : McpServer { - public string EndpointName => server.EndpointName; - public string? SessionId => transport?.SessionId ?? server.SessionId; - public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; - public Implementation? ClientInfo => server.ClientInfo; - public McpServerOptions ServerOptions => server.ServerOptions; - public IServiceProvider? Services => server.Services; - public LoggingLevel? LoggingLevel => server.LoggingLevel; + public override string? SessionId => transport?.SessionId ?? server.SessionId; + public override ClientCapabilities? ClientCapabilities => server.ClientCapabilities; + public override Implementation? ClientInfo => server.ClientInfo; + public override McpServerOptions ServerOptions => server.ServerOptions; + public override IServiceProvider? Services => server.Services; + public override LoggingLevel? LoggingLevel => server.LoggingLevel; - public ValueTask DisposeAsync() => server.DisposeAsync(); + public override ValueTask DisposeAsync() => server.DisposeAsync(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); // This will throw because the server must already be running for this class to be constructed, but it should give us a good Exception message. - public Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); + public override Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Debug.Assert(message.RelatedTransport is null); message.RelatedTransport = transport; return server.SendMessageAsync(message, cancellationToken); } - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { Debug.Assert(request.RelatedTransport is null); request.RelatedTransport = transport; diff --git a/src/ModelContextProtocol.Core/Server/IMcpServer.cs b/src/ModelContextProtocol.Core/Server/IMcpServer.cs index ec2b87ad..31131f81 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpServer.cs @@ -1,10 +1,11 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; /// /// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. /// +[Obsolete($"Use {nameof(McpServer)} instead.")] public interface IMcpServer : IMcpEndpoint { /// diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6c5858f9..87921b38 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -1,597 +1,418 @@ -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using System.Runtime.CompilerServices; -using System.Text.Json.Serialization.Metadata; +using System.Text; +using System.Text.Json; namespace ModelContextProtocol.Server; -/// -internal sealed class McpServer : McpEndpoint, IMcpServer +/// +/// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract class McpServer : McpSession, IMcpServer +#pragma warning restore CS0618 // Type or member is obsolete { - internal static Implementation DefaultImplementation { get; } = new() - { - Name = DefaultAssemblyName.Name ?? nameof(McpServer), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; - - private readonly ITransport _sessionTransport; - private readonly bool _servicesScopePerRequest; - private readonly List _disposables = []; - - private readonly string _serverOnlyEndpointName; - private string? _endpointName; - private int _started; - - /// Holds a boxed value for the server. + /// + /// Gets the capabilities supported by the client. + /// /// - /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box - /// rather than a nullable to be able to manipulate it atomically. + /// + /// These capabilities are established during the initialization handshake and indicate + /// which features the client supports, such as sampling, roots, and other + /// protocol-specific functionality. + /// + /// + /// Server implementations can check these capabilities to determine which features + /// are available when interacting with the client. + /// /// - private StrongBox? _loggingLevel; + public abstract ClientCapabilities? ClientCapabilities { get; } /// - /// Creates a new instance of . + /// Gets the version and implementation information of the connected client. /// - /// Transport to use for the server representing an already-established session. - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. - /// Logger factory to use for logging - /// Optional service provider to use for dependency injection - /// The server was incorrectly configured. - public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : base(loggerFactory) - { - Throw.IfNull(transport); - Throw.IfNull(options); - - options ??= new(); - - _sessionTransport = transport; - ServerOptions = options; - Services = serviceProvider; - _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; - _servicesScopePerRequest = options.ScopeRequests; - - ClientInfo = options.KnownClientInfo; - UpdateEndpointNameWithClientInfo(); - - // Configure all request handlers based on the supplied options. - ServerCapabilities = new(); - ConfigureInitialize(options); - ConfigureTools(options); - ConfigurePrompts(options); - ConfigureResources(options); - ConfigureLogging(options); - ConfigureCompletion(options); - ConfigureExperimental(options); - ConfigurePing(); - - // Register any notification handlers that were provided. - if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) - { - NotificationHandlers.RegisterRange(notificationHandlers); - } - - // Now that everything has been configured, subscribe to any necessary notifications. - if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) - { - Register(ServerOptions.Capabilities?.Tools?.ToolCollection, NotificationMethods.ToolListChangedNotification); - Register(ServerOptions.Capabilities?.Prompts?.PromptCollection, NotificationMethods.PromptListChangedNotification); - Register(ServerOptions.Capabilities?.Resources?.ResourceCollection, NotificationMethods.ResourceListChangedNotification); - - void Register(McpServerPrimitiveCollection? collection, string notificationMethod) - where TPrimitive : IMcpServerPrimitive - { - if (collection is not null) - { - EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(notificationMethod); - collection.Changed += changed; - _disposables.Add(() => collection.Changed -= changed); - } - } - } - - // And initialize the session. - InitializeSession(transport); - } - - /// - public string? SessionId => _sessionTransport.SessionId; - - /// - public ServerCapabilities ServerCapabilities { get; } = new(); - - /// - public ClientCapabilities? ClientCapabilities { get; set; } - - /// - public Implementation? ClientInfo { get; set; } - - /// - public McpServerOptions ServerOptions { get; } - - /// - public IServiceProvider? Services { get; } + /// + /// + /// This property contains identification information about the client that has connected to this server, + /// including its name and version. This information is provided by the client during initialization. + /// + /// + /// Server implementations can use this information for logging, tracking client versions, + /// or implementing client-specific behaviors. + /// + /// + public abstract Implementation? ClientInfo { get; } - /// - public override string EndpointName => _endpointName ?? _serverOnlyEndpointName; + /// + /// Gets the options used to construct this server. + /// + /// + /// These options define the server's capabilities, protocol version, and other configuration + /// settings that were used to initialize the server. + /// + public abstract McpServerOptions ServerOptions { get; } - /// - public LoggingLevel? LoggingLevel => _loggingLevel?.Value; + /// + /// Gets the service provider for the server. + /// + public abstract IServiceProvider? Services { get; } - /// - public async Task RunAsync(CancellationToken cancellationToken = default) - { - if (Interlocked.Exchange(ref _started, 1) != 0) - { - throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); - } + /// Gets the last logging level set by the client, or if it's never been set. + public abstract LoggingLevel? LoggingLevel { get; } - try - { - StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); - await MessageProcessingTask.ConfigureAwait(false); - } - finally - { - await DisposeAsync().ConfigureAwait(false); - } - } + /// + /// Runs the server, listening for and handling client requests. + /// + public abstract Task RunAsync(CancellationToken cancellationToken = default); - public override async ValueTask DisposeUnsynchronizedAsync() + /// + /// Creates a new instance of an . + /// + /// Transport to use for the server representing an already-established MCP session. + /// Configuration options for this server, including capabilities. + /// Logger factory to use for logging. If null, logging will be disabled. + /// Optional service provider to create new instances of tools and other dependencies. + /// An instance that should be disposed when no longer needed. + /// is . + /// is . + public static McpServer Create( + ITransport transport, + McpServerOptions serverOptions, + ILoggerFactory? loggerFactory = null, + IServiceProvider? serviceProvider = null) { - _disposables.ForEach(d => d()); - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } + Throw.IfNull(transport); + Throw.IfNull(serverOptions); - private void ConfigurePing() - { - SetHandler(RequestMethods.Ping, - async (request, _) => new PingResult(), - McpJsonUtilities.JsonContext.Default.JsonNode, - McpJsonUtilities.JsonContext.Default.PingResult); + return new McpServerImpl(transport, serverOptions, loggerFactory, serviceProvider); } - private void ConfigureInitialize(McpServerOptions options) + /// + /// Requests to sample an LLM via the client using the specified request parameters. + /// + /// The parameters for the sampling request. + /// The to monitor for cancellation requests. + /// A task containing the sampling result from the client. + /// The client does not support sampling. + public ValueTask SampleAsync( + CreateMessageRequestParams request, CancellationToken cancellationToken = default) { - RequestHandlers.Set(RequestMethods.Initialize, - async (request, _, _) => - { - ClientCapabilities = request?.Capabilities ?? new(); - ClientInfo = request?.ClientInfo; - - // Use the ClientInfo to update the session EndpointName for logging. - UpdateEndpointNameWithClientInfo(); - GetSessionOrThrow().EndpointName = EndpointName; - - // Negotiate a protocol version. If the server options provide one, use that. - // Otherwise, try to use whatever the client requested as long as it's supported. - // If it's not supported, fall back to the latest supported version. - string? protocolVersion = options.ProtocolVersion; - if (protocolVersion is null) - { - protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSession.SupportedProtocolVersions.Contains(clientProtocolVersion) ? - clientProtocolVersion : - McpSession.LatestProtocolVersion; - } - - return new InitializeResult - { - ProtocolVersion = protocolVersion, - Instructions = options.ServerInstructions, - ServerInfo = options.ServerInfo ?? DefaultImplementation, - Capabilities = ServerCapabilities ?? new(), - }; - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult); + ThrowIfSamplingUnsupported(); + + return SendRequestAsync( + RequestMethods.SamplingCreateMessage, + request, + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult, + cancellationToken: cancellationToken); } - private void ConfigureCompletion(McpServerOptions options) + /// + /// Requests to sample an LLM via the client using the provided chat messages and options. + /// + /// The messages to send as part of the request. + /// The options to use for the request, including model parameters and constraints. + /// The to monitor for cancellation requests. The default is . + /// A task containing the chat response from the model. + /// is . + /// The client does not support sampling. + public async Task SampleAsync( + IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) { - if (options.Capabilities?.Completions is not { } completionsCapability) - { - return; - } + Throw.IfNull(messages); - ServerCapabilities.Completions = new() - { - CompleteHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()) - }; - - SetHandler( - RequestMethods.CompletionComplete, - ServerCapabilities.Completions.CompleteHandler, - McpJsonUtilities.JsonContext.Default.CompleteRequestParams, - McpJsonUtilities.JsonContext.Default.CompleteResult); - } - - private void ConfigureExperimental(McpServerOptions options) - { - ServerCapabilities.Experimental = options.Capabilities?.Experimental; - } + StringBuilder? systemPrompt = null; - private void ConfigureResources(McpServerOptions options) - { - if (options.Capabilities?.Resources is not { } resourcesCapability) + if (options?.Instructions is { } instructions) { - return; + (systemPrompt ??= new()).Append(instructions); } - ServerCapabilities.Resources = new(); - - var listResourcesHandler = resourcesCapability.ListResourcesHandler ?? (static async (_, __) => new ListResourcesResult()); - var listResourceTemplatesHandler = resourcesCapability.ListResourceTemplatesHandler ?? (static async (_, __) => new ListResourceTemplatesResult()); - var readResourceHandler = resourcesCapability.ReadResourceHandler ?? (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); - var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler ?? (static async (_, __) => new EmptyResult()); - var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler ?? (static async (_, __) => new EmptyResult()); - var resources = resourcesCapability.ResourceCollection; - var listChanged = resourcesCapability.ListChanged; - var subscribe = resourcesCapability.Subscribe; - - // Handle resources provided via DI. - if (resources is { IsEmpty: false }) + List samplingMessages = []; + foreach (var message in messages) { - var originalListResourcesHandler = listResourcesHandler; - listResourcesHandler = async (request, cancellationToken) => + if (message.Role == ChatRole.System) { - ListResourcesResult result = originalListResourcesHandler is not null ? - await originalListResourcesHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) + if (systemPrompt is null) { - foreach (var r in resources) - { - if (r.ProtocolResource is { } resource) - { - result.Resources.Add(resource); - } - } + systemPrompt = new(); + } + else + { + systemPrompt.AppendLine(); } - return result; - }; + systemPrompt.Append(message.Text); + continue; + } - var originalListResourceTemplatesHandler = listResourceTemplatesHandler; - listResourceTemplatesHandler = async (request, cancellationToken) => + if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) { - ListResourceTemplatesResult result = originalListResourceTemplatesHandler is not null ? - await originalListResourceTemplatesHandler(request, cancellationToken).ConfigureAwait(false) : - new(); + Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; - if (request.Params?.Cursor is null) + foreach (var content in message.Contents) { - foreach (var rt in resources) + switch (content) { - if (rt.IsTemplated) - { - result.ResourceTemplates.Add(rt.ProtocolResourceTemplate); - } + case TextContent textContent: + samplingMessages.Add(new() + { + Role = role, + Content = new TextContentBlock { Text = textContent.Text }, + }); + break; + + case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): + samplingMessages.Add(new() + { + Role = role, + Content = dataContent.HasTopLevelMediaType("image") ? + new ImageContentBlock + { + MimeType = dataContent.MediaType, + Data = dataContent.Base64Data.ToString(), + } : + new AudioContentBlock + { + MimeType = dataContent.MediaType, + Data = dataContent.Base64Data.ToString(), + }, + }); + break; } } + } + } - return result; - }; + ModelPreferences? modelPreferences = null; + if (options?.ModelId is { } modelId) + { + modelPreferences = new() { Hints = [new() { Name = modelId }] }; + } - // Synthesize read resource handler, which covers both resources and resource templates. - var originalReadResourceHandler = readResourceHandler; - readResourceHandler = async (request, cancellationToken) => + var result = await SampleAsync(new() { - if (request.Params?.Uri is string uri) - { - // First try an O(1) lookup by exact match. - if (resources.TryGetPrimitive(uri, out var resource)) - { - if (await resource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } + Messages = samplingMessages, + MaxTokens = options?.MaxOutputTokens, + StopSequences = options?.StopSequences?.ToArray(), + SystemPrompt = systemPrompt?.ToString(), + Temperature = options?.Temperature, + ModelPreferences = modelPreferences, + }, cancellationToken).ConfigureAwait(false); - // Fall back to an O(N) lookup, trying to match against each URI template. - // The number of templates is controlled by the server developer, and the number is expected to be - // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. - foreach (var resourceTemplate in resources) - { - if (await resourceTemplate.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } - } - - // Finally fall back to the handler. - return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); - }; + AIContent? responseContent = result.Content.ToAIContent(); - listChanged = true; - - // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. - // subscribe = true; - } + return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) + { + ModelId = result.Model, + FinishReason = result.StopReason switch + { + "maxTokens" => ChatFinishReason.Length, + "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, + } + }; + } - ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; - ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; - ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; - ServerCapabilities.Resources.ResourceCollection = resources; - ServerCapabilities.Resources.SubscribeToResourcesHandler = subscribeHandler; - ServerCapabilities.Resources.UnsubscribeFromResourcesHandler = unsubscribeHandler; - ServerCapabilities.Resources.ListChanged = listChanged; - ServerCapabilities.Resources.Subscribe = subscribe; - - SetHandler( - RequestMethods.ResourcesList, - listResourcesHandler, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult); - - SetHandler( - RequestMethods.ResourcesTemplatesList, - listResourceTemplatesHandler, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); - - SetHandler( - RequestMethods.ResourcesRead, - readResourceHandler, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult); - - SetHandler( - RequestMethods.ResourcesSubscribe, - subscribeHandler, - McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - - SetHandler( - RequestMethods.ResourcesUnsubscribe, - unsubscribeHandler, - McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); + /// + /// Creates an wrapper that can be used to send sampling requests to the client. + /// + /// The that can be used to issue sampling requests to the client. + /// The client does not support sampling. + public IChatClient AsSamplingChatClient() + { + ThrowIfSamplingUnsupported(); + return new SamplingChatClient(this); } - private void ConfigurePrompts(McpServerOptions options) + /// Gets an on which logged messages will be sent as notifications to the client. + /// An that can be used to log to the client.. + public ILoggerProvider AsClientLoggerProvider() { - if (options.Capabilities?.Prompts is not { } promptsCapability) - { - return; - } + return new ClientLoggerProvider(this); + } - ServerCapabilities.Prompts = new(); + /// + /// Requests the client to list the roots it exposes. + /// + /// The parameters for the list roots request. + /// The to monitor for cancellation requests. + /// A task containing the list of roots exposed by the client. + /// The client does not support roots. + public ValueTask RequestRootsAsync( + ListRootsRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfRootsUnsupported(); + + return SendRequestAsync( + RequestMethods.RootsList, + request, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult, + cancellationToken: cancellationToken); + } - var listPromptsHandler = promptsCapability.ListPromptsHandler ?? (static async (_, __) => new ListPromptsResult()); - var getPromptHandler = promptsCapability.GetPromptHandler ?? (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - var prompts = promptsCapability.PromptCollection; - var listChanged = promptsCapability.ListChanged; + /// + /// Requests additional information from the user via the client, allowing the server to elicit structured data. + /// + /// The parameters for the elicitation request. + /// The to monitor for cancellation requests. + /// A task containing the elicitation result. + /// The client does not support elicitation. + public ValueTask ElicitAsync( + ElicitRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfElicitationUnsupported(); + + return SendRequestAsync( + RequestMethods.ElicitationCreate, + request, + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult, + cancellationToken: cancellationToken); + } - // Handle tools provided via DI by augmenting the handlers to incorporate them. - if (prompts is { IsEmpty: false }) + private void ThrowIfSamplingUnsupported() + { + if (ClientCapabilities?.Sampling is null) { - var originalListPromptsHandler = listPromptsHandler; - listPromptsHandler = async (request, cancellationToken) => + if (ServerOptions.KnownClientInfo is not null) { - ListPromptsResult result = originalListPromptsHandler is not null ? - await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var p in prompts) - { - result.Prompts.Add(p.ProtocolPrompt); - } - } + throw new InvalidOperationException("Sampling is not supported in stateless mode."); + } - return result; - }; + throw new InvalidOperationException("Client does not support sampling."); + } + } - var originalGetPromptHandler = getPromptHandler; - getPromptHandler = (request, cancellationToken) => + private void ThrowIfRootsUnsupported() + { + if (ClientCapabilities?.Roots is null) + { + if (ServerOptions.KnownClientInfo is not null) { - if (request.Params is not null && - prompts.TryGetPrimitive(request.Params.Name, out var prompt)) - { - return prompt.GetAsync(request, cancellationToken); - } - - return originalGetPromptHandler(request, cancellationToken); - }; + throw new InvalidOperationException("Roots are not supported in stateless mode."); + } - listChanged = true; + throw new InvalidOperationException("Client does not support roots."); } - - ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; - ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; - ServerCapabilities.Prompts.PromptCollection = prompts; - ServerCapabilities.Prompts.ListChanged = listChanged; - - SetHandler( - RequestMethods.PromptsList, - listPromptsHandler, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult); - - SetHandler( - RequestMethods.PromptsGet, - getPromptHandler, - McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, - McpJsonUtilities.JsonContext.Default.GetPromptResult); } - private void ConfigureTools(McpServerOptions options) + private void ThrowIfElicitationUnsupported() { - if (options.Capabilities?.Tools is not { } toolsCapability) + if (ClientCapabilities?.Elicitation is null) { - return; + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Elicitation is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support elicitation requests."); } + } - ServerCapabilities.Tools = new(); + /// Provides an implementation that's implemented via client sampling. + private sealed class SamplingChatClient : IChatClient + { + private readonly McpServer _server; - var listToolsHandler = toolsCapability.ListToolsHandler ?? (static async (_, __) => new ListToolsResult()); - var callToolHandler = toolsCapability.CallToolHandler ?? (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - var tools = toolsCapability.ToolCollection; - var listChanged = toolsCapability.ListChanged; + public SamplingChatClient(McpServer server) => _server = server; - // Handle tools provided via DI by augmenting the handlers to incorporate them. - if (tools is { IsEmpty: false }) - { - var originalListToolsHandler = listToolsHandler; - listToolsHandler = async (request, cancellationToken) => - { - ListToolsResult result = originalListToolsHandler is not null ? - await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : - new(); + /// + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + _server.SampleAsync(messages, options, cancellationToken); - if (request.Params?.Cursor is null) - { - foreach (var t in tools) - { - result.Tools.Add(t.ProtocolTool); - } - } - - return result; - }; - - var originalCallToolHandler = callToolHandler; - callToolHandler = (request, cancellationToken) => + /// + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + foreach (var update in response.ToChatResponseUpdates()) { - if (request.Params is not null && - tools.TryGetPrimitive(request.Params.Name, out var tool)) - { - return tool.InvokeAsync(request, cancellationToken); - } + yield return update; + } + } - return originalCallToolHandler(request, cancellationToken); - }; + /// + object? IChatClient.GetService(Type serviceType, object? serviceKey) + { + Throw.IfNull(serviceType); - listChanged = true; + return + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(_server) ? _server : + null; } - ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; - ServerCapabilities.Tools.CallToolHandler = callToolHandler; - ServerCapabilities.Tools.ToolCollection = tools; - ServerCapabilities.Tools.ListChanged = listChanged; - - SetHandler( - RequestMethods.ToolsList, - listToolsHandler, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult); - - SetHandler( - RequestMethods.ToolsCall, - callToolHandler, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult); + /// + void IDisposable.Dispose() { } // nop } - private void ConfigureLogging(McpServerOptions options) + /// + /// Provides an implementation for creating loggers + /// that send logging message notifications to the client for logged messages. + /// + private sealed class ClientLoggerProvider : ILoggerProvider { - // We don't require that the handler be provided, as we always store the provided log level to the server. - var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; + private readonly McpServer _server; - ServerCapabilities.Logging = new(); - ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; + public ClientLoggerProvider(McpServer server) => _server = server; - RequestHandlers.Set( - RequestMethods.LoggingSetLevel, - (request, destinationTransport, cancellationToken) => - { - // Store the provided level. - if (request is not null) - { - if (_loggingLevel is null) - { - Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); - } - - _loggingLevel.Value = request.Level; - } + /// + public ILogger CreateLogger(string categoryName) + { + Throw.IfNull(categoryName); - // If a handler was provided, now delegate to it. - if (setLoggingLevelHandler is not null) - { - return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken); - } + return new ClientLogger(_server, categoryName); + } - // Otherwise, consider it handled. - return new ValueTask(EmptyResult.Instance); - }, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - } + /// + void IDisposable.Dispose() { } - private ValueTask InvokeHandlerAsync( - Func, CancellationToken, ValueTask> handler, - TParams? args, - ITransport? destinationTransport = null, - CancellationToken cancellationToken = default) - { - return _servicesScopePerRequest ? - InvokeScopedAsync(handler, args, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken); - - async ValueTask InvokeScopedAsync( - Func, CancellationToken, ValueTask> handler, - TParams? args, - CancellationToken cancellationToken) + private sealed class ClientLogger : ILogger { - var scope = Services?.GetService()?.CreateAsyncScope(); - try + private readonly McpServer _server; + private readonly string _categoryName; + + public ClientLogger(McpServer server, string categoryName) { - return await handler( - new RequestContext(new DestinationBoundMcpServer(this, destinationTransport)) - { - Services = scope?.ServiceProvider ?? Services, - Params = args - }, - cancellationToken).ConfigureAwait(false); + _server = server; + _categoryName = categoryName; } - finally + + /// + public IDisposable? BeginScope(TState state) where TState : notnull => + null; + + /// + public bool IsEnabled(LogLevel logLevel) => + _server?.LoggingLevel is { } loggingLevel && + McpServerImpl.ToLoggingLevel(logLevel) >= loggingLevel; + + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { - if (scope is not null) + if (!IsEnabled(logLevel)) { - await scope.Value.DisposeAsync().ConfigureAwait(false); + return; } - } - } - } - private void SetHandler( - string method, - Func, CancellationToken, ValueTask> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) - { - RequestHandlers.Set(method, - (request, destinationTransport, cancellationToken) => - InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), - requestTypeInfo, responseTypeInfo); - } + Throw.IfNull(formatter); - private void UpdateEndpointNameWithClientInfo() - { - if (ClientInfo is null) - { - return; - } + LogInternal(logLevel, formatter(state, exception)); - _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; + void LogInternal(LogLevel level, string message) + { + _ = _server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams + { + Level = McpServerImpl.ToLoggingLevel(level), + Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), + Logger = _categoryName, + }); + } + } + } } - - /// Maps a to a . - internal static LoggingLevel ToLoggingLevel(LogLevel level) => - level switch - { - LogLevel.Trace => Protocol.LoggingLevel.Debug, - LogLevel.Debug => Protocol.LoggingLevel.Debug, - LogLevel.Information => Protocol.LoggingLevel.Info, - LogLevel.Warning => Protocol.LoggingLevel.Warning, - LogLevel.Error => Protocol.LoggingLevel.Error, - LogLevel.Critical => Protocol.LoggingLevel.Critical, - _ => Protocol.LoggingLevel.Emergency, - }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs index 277ed737..98edd3fc 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs @@ -1,9 +1,8 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; -using System.Text; -using System.Text.Json; namespace ModelContextProtocol.Server; @@ -26,19 +25,10 @@ public static class McpServerExtensions /// It allows detailed control over sampling parameters including messages, system prompt, temperature, /// and token limits. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.SampleAsync)} instead.")] public static ValueTask SampleAsync( this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfSamplingUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.SamplingCreateMessage, - request, - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).SampleAsync(request, cancellationToken); /// /// Requests to sample an LLM via the client using the provided chat messages and options. @@ -55,104 +45,11 @@ public static ValueTask SampleAsync( /// This method converts the provided chat messages into a format suitable for the sampling API, /// handling different content types such as text, images, and audio. /// - public static async Task SampleAsync( + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.SampleAsync)} instead.")] + public static Task SampleAsync( this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - Throw.IfNull(messages); - - StringBuilder? systemPrompt = null; - - if (options?.Instructions is { } instructions) - { - (systemPrompt ??= new()).Append(instructions); - } - - List samplingMessages = []; - foreach (var message in messages) - { - if (message.Role == ChatRole.System) - { - if (systemPrompt is null) - { - systemPrompt = new(); - } - else - { - systemPrompt.AppendLine(); - } - - systemPrompt.Append(message.Text); - continue; - } - - if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) - { - Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; - - foreach (var content in message.Contents) - { - switch (content) - { - case TextContent textContent: - samplingMessages.Add(new() - { - Role = role, - Content = new TextContentBlock { Text = textContent.Text }, - }); - break; - - case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): - samplingMessages.Add(new() - { - Role = role, - Content = dataContent.HasTopLevelMediaType("image") ? - new ImageContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - } : - new AudioContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - }, - }); - break; - } - } - } - } - - ModelPreferences? modelPreferences = null; - if (options?.ModelId is { } modelId) - { - modelPreferences = new() { Hints = [new() { Name = modelId }] }; - } - - var result = await server.SampleAsync(new() - { - Messages = samplingMessages, - MaxTokens = options?.MaxOutputTokens, - StopSequences = options?.StopSequences?.ToArray(), - SystemPrompt = systemPrompt?.ToString(), - Temperature = options?.Temperature, - ModelPreferences = modelPreferences, - }, cancellationToken).ConfigureAwait(false); - - AIContent? responseContent = result.Content.ToAIContent(); - - return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) - { - ModelId = result.Model, - FinishReason = result.StopReason switch - { - "maxTokens" => ChatFinishReason.Length, - "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, - } - }; - } + => AsServerOrThrow(server).SampleAsync(messages, options, cancellationToken); /// /// Creates an wrapper that can be used to send sampling requests to the client. @@ -161,23 +58,16 @@ public static async Task SampleAsync( /// The that can be used to issue sampling requests to the client. /// is . /// The client does not support sampling. + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.AsSamplingChatClient)} instead.")] public static IChatClient AsSamplingChatClient(this IMcpServer server) - { - Throw.IfNull(server); - ThrowIfSamplingUnsupported(server); - - return new SamplingChatClient(server); - } + => AsServerOrThrow(server).AsSamplingChatClient(); /// Gets an on which logged messages will be sent as notifications to the client. /// The server to wrap as an . /// An that can be used to log to the client.. + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.AsSamplingChatClient)} instead.")] public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) - { - Throw.IfNull(server); - - return new ClientLoggerProvider(server); - } + => AsServerOrThrow(server).AsClientLoggerProvider(); /// /// Requests the client to list the roots it exposes. @@ -194,19 +84,10 @@ public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) /// navigated and accessed by the server. These resources might include file systems, databases, /// or other structured data sources that the client makes available through the protocol. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.RequestRootsAsync)} instead.")] public static ValueTask RequestRootsAsync( this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfRootsUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.RootsList, - request, - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).RequestRootsAsync(request, cancellationToken); /// /// Requests additional information from the user via the client, allowing the server to elicit structured data. @@ -220,143 +101,30 @@ public static ValueTask RequestRootsAsync( /// /// This method requires the client to support the elicitation capability. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.ElicitAsync)} instead.")] public static ValueTask ElicitAsync( this IMcpServer server, ElicitRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfElicitationUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.ElicitationCreate, - request, - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult, - cancellationToken: cancellationToken); - } - - private static void ThrowIfSamplingUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Sampling is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Sampling is not supported in stateless mode."); - } + => AsServerOrThrow(server).ElicitAsync(request, cancellationToken); - throw new InvalidOperationException("Client does not support sampling."); - } - } - - private static void ThrowIfRootsUnsupported(IMcpServer server) + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpServer AsServerOrThrow(IMcpServer server, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - if (server.ClientCapabilities?.Roots is null) + if (server is not McpServer mcpServer) { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Roots are not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support roots."); + ThrowInvalidEndpointType(memberName); } - } - private static void ThrowIfElicitationUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Elicitation is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Elicitation is not supported in stateless mode."); - } + return mcpServer; - throw new InvalidOperationException("Client does not support elicitation requests."); - } - } - - /// Provides an implementation that's implemented via client sampling. - private sealed class SamplingChatClient(IMcpServer server) : IChatClient - { - /// - public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - server.SampleAsync(messages, options, cancellationToken); - - /// - async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( - IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) - { - var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - foreach (var update in response.ToChatResponseUpdates()) - { - yield return update; - } - } - - /// - object? IChatClient.GetService(Type serviceType, object? serviceKey) - { - Throw.IfNull(serviceType); - - return - serviceKey is not null ? null : - serviceType.IsInstanceOfType(this) ? this : - serviceType.IsInstanceOfType(server) ? server : - null; - } - - /// - void IDisposable.Dispose() { } // nop - } - - /// - /// Provides an implementation for creating loggers - /// that send logging message notifications to the client for logged messages. - /// - private sealed class ClientLoggerProvider(IMcpServer server) : ILoggerProvider - { - /// - public ILogger CreateLogger(string categoryName) - { - Throw.IfNull(categoryName); - - return new ClientLogger(server, categoryName); - } - - /// - void IDisposable.Dispose() { } - - private sealed class ClientLogger(IMcpServer server, string categoryName) : ILogger - { - /// - public IDisposable? BeginScope(TState state) where TState : notnull => - null; - - /// - public bool IsEnabled(LogLevel logLevel) => - server?.LoggingLevel is { } loggingLevel && - McpServer.ToLoggingLevel(logLevel) >= loggingLevel; - - /// - public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) - { - if (!IsEnabled(logLevel)) - { - return; - } - - Throw.IfNull(formatter); - - Log(logLevel, formatter(state, exception)); - - void Log(LogLevel logLevel, string message) - { - _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams - { - Level = McpServer.ToLoggingLevel(logLevel), - Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), - Logger = categoryName, - }); - } - } - } + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpServer)}' are supported. " + + $"Prefer using '{nameof(McpServer)}.{memberName}' instead, as " + + $"'{nameof(McpServerExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs index 50d4188b..79384b7c 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; @@ -10,6 +10,7 @@ namespace ModelContextProtocol.Server; /// This is the recommended way to create instances. /// The factory handles proper initialization of server instances with the required dependencies. /// +[Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.Create)} instead.")] public static class McpServerFactory { /// @@ -27,10 +28,5 @@ public static IMcpServer Create( McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null, IServiceProvider? serviceProvider = null) - { - Throw.IfNull(transport); - Throw.IfNull(serverOptions); - - return new McpServer(transport, serverOptions, loggerFactory, serviceProvider); - } + => McpServer.Create(transport, serverOptions, loggerFactory, serviceProvider); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs new file mode 100644 index 00000000..1f83fc08 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -0,0 +1,626 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Runtime.CompilerServices; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Server; + +/// +internal sealed class McpServerImpl : McpServer +{ + internal static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpServer), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _sessionTransport; + private readonly bool _servicesScopePerRequest; + private readonly List _disposables = []; + private readonly NotificationHandlers _notificationHandlers; + private readonly RequestHandlers _requestHandlers; + private readonly McpSessionHandler _sessionHandler; + + private ClientCapabilities? _clientCapabilities; + private Implementation? _clientInfo; + + private readonly string _serverOnlyEndpointName; + private string _endpointName; + private int _started; + + private int _isDisposed; + + /// Holds a boxed value for the server. + /// + /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box + /// rather than a nullable to be able to manipulate it atomically. + /// + private StrongBox? _loggingLevel; + + /// + /// Creates a new instance of . + /// + /// Transport to use for the server representing an already-established session. + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// Logger factory to use for logging + /// Optional service provider to use for dependency injection + /// The server was incorrectly configured. + public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) + { + Throw.IfNull(transport); + Throw.IfNull(options); + + options ??= new(); + + _sessionTransport = transport; + ServerOptions = options; + Services = serviceProvider; + _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _endpointName = _serverOnlyEndpointName; + _servicesScopePerRequest = options.ScopeRequests; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + _clientInfo = options.KnownClientInfo; + UpdateEndpointNameWithClientInfo(); + + _notificationHandlers = new(); + _requestHandlers = []; + + // Configure all request handlers based on the supplied options. + ServerCapabilities = new(); + ConfigureInitialize(options); + ConfigureTools(options); + ConfigurePrompts(options); + ConfigureResources(options); + ConfigureLogging(options); + ConfigureCompletion(options); + ConfigureExperimental(options); + ConfigurePing(); + + // Register any notification handlers that were provided. + if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) + { + _notificationHandlers.RegisterRange(notificationHandlers); + } + + // Now that everything has been configured, subscribe to any necessary notifications. + if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) + { + Register(ServerOptions.Capabilities?.Tools?.ToolCollection, NotificationMethods.ToolListChangedNotification); + Register(ServerOptions.Capabilities?.Prompts?.PromptCollection, NotificationMethods.PromptListChangedNotification); + Register(ServerOptions.Capabilities?.Resources?.ResourceCollection, NotificationMethods.ResourceListChangedNotification); + + void Register(McpServerPrimitiveCollection? collection, string notificationMethod) + where TPrimitive : IMcpServerPrimitive + { + if (collection is not null) + { + EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(notificationMethod); + collection.Changed += changed; + _disposables.Add(() => collection.Changed -= changed); + } + } + } + + // And initialize the session. + _sessionHandler = new McpSessionHandler(isServer: true, _sessionTransport, _endpointName!, _requestHandlers, _notificationHandlers, _logger); + } + + /// + public override string? SessionId => _sessionTransport.SessionId; + + /// + public ServerCapabilities ServerCapabilities { get; } = new(); + + /// + public override ClientCapabilities? ClientCapabilities => _clientCapabilities; + + /// + public override Implementation? ClientInfo => _clientInfo; + + /// + public override McpServerOptions ServerOptions { get; } + + /// + public override IServiceProvider? Services { get; } + + /// + public override LoggingLevel? LoggingLevel => _loggingLevel?.Value; + + /// + public override async Task RunAsync(CancellationToken cancellationToken = default) + { + if (Interlocked.Exchange(ref _started, 1) != 0) + { + throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); + } + + try + { + await _sessionHandler.ProcessMessagesAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + await DisposeAsync().ConfigureAwait(false); + } + } + + + /// + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public override async ValueTask DisposeAsync() + { + if (Interlocked.CompareExchange(ref _isDisposed, 1, 0) != 0) + { + return; + } + + _disposables.ForEach(d => d()); + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + } + + private void ConfigurePing() + { + SetHandler(RequestMethods.Ping, + async (request, _) => new PingResult(), + McpJsonUtilities.JsonContext.Default.JsonNode, + McpJsonUtilities.JsonContext.Default.PingResult); + } + + private void ConfigureInitialize(McpServerOptions options) + { + _requestHandlers.Set(RequestMethods.Initialize, + async (request, _, _) => + { + _clientCapabilities = request?.Capabilities ?? new(); + _clientInfo = request?.ClientInfo; + + // Use the ClientInfo to update the session EndpointName for logging. + UpdateEndpointNameWithClientInfo(); + _sessionHandler.EndpointName = _endpointName; + + // Negotiate a protocol version. If the server options provide one, use that. + // Otherwise, try to use whatever the client requested as long as it's supported. + // If it's not supported, fall back to the latest supported version. + string? protocolVersion = options.ProtocolVersion; + if (protocolVersion is null) + { + protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? + clientProtocolVersion : + McpSessionHandler.LatestProtocolVersion; + } + + return new InitializeResult + { + ProtocolVersion = protocolVersion, + Instructions = options.ServerInstructions, + ServerInfo = options.ServerInfo ?? DefaultImplementation, + Capabilities = ServerCapabilities ?? new(), + }; + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult); + } + + private void ConfigureCompletion(McpServerOptions options) + { + if (options.Capabilities?.Completions is not { } completionsCapability) + { + return; + } + + ServerCapabilities.Completions = new() + { + CompleteHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()) + }; + + SetHandler( + RequestMethods.CompletionComplete, + ServerCapabilities.Completions.CompleteHandler, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult); + } + + private void ConfigureExperimental(McpServerOptions options) + { + ServerCapabilities.Experimental = options.Capabilities?.Experimental; + } + + private void ConfigureResources(McpServerOptions options) + { + if (options.Capabilities?.Resources is not { } resourcesCapability) + { + return; + } + + ServerCapabilities.Resources = new(); + + var listResourcesHandler = resourcesCapability.ListResourcesHandler ?? (static async (_, __) => new ListResourcesResult()); + var listResourceTemplatesHandler = resourcesCapability.ListResourceTemplatesHandler ?? (static async (_, __) => new ListResourceTemplatesResult()); + var readResourceHandler = resourcesCapability.ReadResourceHandler ?? (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); + var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler ?? (static async (_, __) => new EmptyResult()); + var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler ?? (static async (_, __) => new EmptyResult()); + var resources = resourcesCapability.ResourceCollection; + var listChanged = resourcesCapability.ListChanged; + var subscribe = resourcesCapability.Subscribe; + + // Handle resources provided via DI. + if (resources is { IsEmpty: false }) + { + var originalListResourcesHandler = listResourcesHandler; + listResourcesHandler = async (request, cancellationToken) => + { + ListResourcesResult result = originalListResourcesHandler is not null ? + await originalListResourcesHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var r in resources) + { + if (r.ProtocolResource is { } resource) + { + result.Resources.Add(resource); + } + } + } + + return result; + }; + + var originalListResourceTemplatesHandler = listResourceTemplatesHandler; + listResourceTemplatesHandler = async (request, cancellationToken) => + { + ListResourceTemplatesResult result = originalListResourceTemplatesHandler is not null ? + await originalListResourceTemplatesHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var rt in resources) + { + if (rt.IsTemplated) + { + result.ResourceTemplates.Add(rt.ProtocolResourceTemplate); + } + } + } + + return result; + }; + + // Synthesize read resource handler, which covers both resources and resource templates. + var originalReadResourceHandler = readResourceHandler; + readResourceHandler = async (request, cancellationToken) => + { + if (request.Params?.Uri is string uri) + { + // First try an O(1) lookup by exact match. + if (resources.TryGetPrimitive(uri, out var resource)) + { + if (await resource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) + { + return result; + } + } + + // Fall back to an O(N) lookup, trying to match against each URI template. + // The number of templates is controlled by the server developer, and the number is expected to be + // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. + foreach (var resourceTemplate in resources) + { + if (await resourceTemplate.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) + { + return result; + } + } + } + + // Finally fall back to the handler. + return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); + }; + + listChanged = true; + + // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. + // subscribe = true; + } + + ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; + ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; + ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; + ServerCapabilities.Resources.ResourceCollection = resources; + ServerCapabilities.Resources.SubscribeToResourcesHandler = subscribeHandler; + ServerCapabilities.Resources.UnsubscribeFromResourcesHandler = unsubscribeHandler; + ServerCapabilities.Resources.ListChanged = listChanged; + ServerCapabilities.Resources.Subscribe = subscribe; + + SetHandler( + RequestMethods.ResourcesList, + listResourcesHandler, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult); + + SetHandler( + RequestMethods.ResourcesTemplatesList, + listResourceTemplatesHandler, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); + + SetHandler( + RequestMethods.ResourcesRead, + readResourceHandler, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult); + + SetHandler( + RequestMethods.ResourcesSubscribe, + subscribeHandler, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + + SetHandler( + RequestMethods.ResourcesUnsubscribe, + unsubscribeHandler, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + } + + private void ConfigurePrompts(McpServerOptions options) + { + if (options.Capabilities?.Prompts is not { } promptsCapability) + { + return; + } + + ServerCapabilities.Prompts = new(); + + var listPromptsHandler = promptsCapability.ListPromptsHandler ?? (static async (_, __) => new ListPromptsResult()); + var getPromptHandler = promptsCapability.GetPromptHandler ?? (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + var prompts = promptsCapability.PromptCollection; + var listChanged = promptsCapability.ListChanged; + + // Handle tools provided via DI by augmenting the handlers to incorporate them. + if (prompts is { IsEmpty: false }) + { + var originalListPromptsHandler = listPromptsHandler; + listPromptsHandler = async (request, cancellationToken) => + { + ListPromptsResult result = originalListPromptsHandler is not null ? + await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var p in prompts) + { + result.Prompts.Add(p.ProtocolPrompt); + } + } + + return result; + }; + + var originalGetPromptHandler = getPromptHandler; + getPromptHandler = (request, cancellationToken) => + { + if (request.Params is not null && + prompts.TryGetPrimitive(request.Params.Name, out var prompt)) + { + return prompt.GetAsync(request, cancellationToken); + } + + return originalGetPromptHandler(request, cancellationToken); + }; + + listChanged = true; + } + + ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; + ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; + ServerCapabilities.Prompts.PromptCollection = prompts; + ServerCapabilities.Prompts.ListChanged = listChanged; + + SetHandler( + RequestMethods.PromptsList, + listPromptsHandler, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult); + + SetHandler( + RequestMethods.PromptsGet, + getPromptHandler, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult); + } + + private void ConfigureTools(McpServerOptions options) + { + if (options.Capabilities?.Tools is not { } toolsCapability) + { + return; + } + + ServerCapabilities.Tools = new(); + + var listToolsHandler = toolsCapability.ListToolsHandler ?? (static async (_, __) => new ListToolsResult()); + var callToolHandler = toolsCapability.CallToolHandler ?? (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + var tools = toolsCapability.ToolCollection; + var listChanged = toolsCapability.ListChanged; + + // Handle tools provided via DI by augmenting the handlers to incorporate them. + if (tools is { IsEmpty: false }) + { + var originalListToolsHandler = listToolsHandler; + listToolsHandler = async (request, cancellationToken) => + { + ListToolsResult result = originalListToolsHandler is not null ? + await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var t in tools) + { + result.Tools.Add(t.ProtocolTool); + } + } + + return result; + }; + + var originalCallToolHandler = callToolHandler; + callToolHandler = (request, cancellationToken) => + { + if (request.Params is not null && + tools.TryGetPrimitive(request.Params.Name, out var tool)) + { + return tool.InvokeAsync(request, cancellationToken); + } + + return originalCallToolHandler(request, cancellationToken); + }; + + listChanged = true; + } + + ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; + ServerCapabilities.Tools.CallToolHandler = callToolHandler; + ServerCapabilities.Tools.ToolCollection = tools; + ServerCapabilities.Tools.ListChanged = listChanged; + + SetHandler( + RequestMethods.ToolsList, + listToolsHandler, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult); + + SetHandler( + RequestMethods.ToolsCall, + callToolHandler, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult); + } + + private void ConfigureLogging(McpServerOptions options) + { + // We don't require that the handler be provided, as we always store the provided log level to the server. + var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; + + ServerCapabilities.Logging = new(); + ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; + + _requestHandlers.Set( + RequestMethods.LoggingSetLevel, + (request, destinationTransport, cancellationToken) => + { + // Store the provided level. + if (request is not null) + { + if (_loggingLevel is null) + { + Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); + } + + _loggingLevel.Value = request.Level; + } + + // If a handler was provided, now delegate to it. + if (setLoggingLevelHandler is not null) + { + return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken); + } + + // Otherwise, consider it handled. + return new ValueTask(EmptyResult.Instance); + }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + } + + private ValueTask InvokeHandlerAsync( + Func, CancellationToken, ValueTask> handler, + TParams? args, + ITransport? destinationTransport = null, + CancellationToken cancellationToken = default) + { + return _servicesScopePerRequest ? + InvokeScopedAsync(handler, args, cancellationToken) : + handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken); + + async ValueTask InvokeScopedAsync( + Func, CancellationToken, ValueTask> handler, + TParams? args, + CancellationToken cancellationToken) + { + var scope = Services?.GetService()?.CreateAsyncScope(); + try + { + return await handler( + new RequestContext(new DestinationBoundMcpServer(this, destinationTransport)) + { + Services = scope?.ServiceProvider ?? Services, + Params = args + }, + cancellationToken).ConfigureAwait(false); + } + finally + { + if (scope is not null) + { + await scope.Value.DisposeAsync().ConfigureAwait(false); + } + } + } + } + + private void SetHandler( + string method, + Func, CancellationToken, ValueTask> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) + { + _requestHandlers.Set(method, + (request, destinationTransport, cancellationToken) => + InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), + requestTypeInfo, responseTypeInfo); + } + + private void UpdateEndpointNameWithClientInfo() + { + if (ClientInfo is null) + { + return; + } + + _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; + } + + /// Maps a to a . + internal static LoggingLevel ToLoggingLevel(LogLevel level) => + level switch + { + LogLevel.Trace => Protocol.LoggingLevel.Debug, + LogLevel.Debug => Protocol.LoggingLevel.Debug, + LogLevel.Information => Protocol.LoggingLevel.Info, + LogLevel.Warning => Protocol.LoggingLevel.Warning, + LogLevel.Error => Protocol.LoggingLevel.Error, + LogLevel.Critical => Protocol.LoggingLevel.Critical, + _ => Protocol.LoggingLevel.Emergency, + }; +} diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs index 68874df3..41d6b1cb 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs @@ -15,8 +15,8 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP prompt for use in the server (as opposed /// to , which provides the protocol representation of a prompt, and , which /// provides a client-side representation of a prompt). Instances of can be added into a -/// to be picked up automatically when is used to create -/// an , or added into a . +/// to be picked up automatically when is used to create +/// an , or added into a . /// /// /// Most commonly, instances are created using the static methods. @@ -34,7 +34,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -45,7 +45,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -201,7 +201,7 @@ public static McpServerPrompt Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerPrompt Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs index c71e969d..ac9e247f 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs @@ -25,7 +25,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -36,7 +36,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerResource.cs b/src/ModelContextProtocol.Core/Server/McpServerResource.cs index 8e42d3e1..fe36b284 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResource.cs @@ -13,7 +13,7 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP resource for use in the server (as opposed /// to or , which provide the protocol representations of a resource). Instances of /// can be added into a to be picked up automatically when -/// is used to create an , or added into a . +/// is used to create an , or added into a . /// /// /// Most commonly, instances are created using the static methods. @@ -35,7 +35,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -46,7 +46,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -223,7 +223,7 @@ public static McpServerResource Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerResource Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs index bc2f138f..66c593e4 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs @@ -23,7 +23,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -34,7 +34,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index e3958271..76391a50 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -15,8 +15,8 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP tool for use in the server (as opposed /// to , which provides the protocol representation of a tool, and , which /// provides a client-side representation of a tool). Instances of can be added into a -/// to be picked up automatically when is used to create -/// an , or added into a . +/// to be picked up automatically when is used to create +/// an , or added into a . /// /// /// Most commonly, instances are created using the static methods. @@ -35,7 +35,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . The parameter is not included in the generated JSON schema. /// /// @@ -47,7 +47,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are not included in the JSON schema and are bound directly to the +/// parameters are not included in the JSON schema and are bound directly to the /// instance associated with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -203,7 +203,7 @@ public static McpServerTool Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerTool Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index d4ea9eb7..7d5bf488 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -26,7 +26,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . The parameter is not included in the generated JSON schema. /// /// @@ -38,7 +38,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are not included in the JSON schema and are bound directly to the +/// parameters are not included in the JSON schema and are bound directly to the /// instance associated with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index b0ea9d99..37d24b98 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -12,13 +12,13 @@ namespace ModelContextProtocol.Server; public sealed class RequestContext { /// The server with which this instance is associated. - private IMcpServer _server; + private McpServer _server; /// /// Initializes a new instance of the class with the specified server. /// /// The server with which this instance is associated. - public RequestContext(IMcpServer server) + public RequestContext(McpServer server) { Throw.IfNull(server); @@ -27,7 +27,7 @@ public RequestContext(IMcpServer server) } /// Gets or sets the server with which this instance is associated. - public IMcpServer Server + public McpServer Server { get => _server; set @@ -39,10 +39,10 @@ public IMcpServer Server /// Gets or sets the services associated with this request. /// - /// This may not be the same instance stored in + /// This may not be the same instance stored in /// if was true, in which case this /// might be a scoped derived from the server's - /// . + /// . /// public IServiceProvider? Services { get; set; } diff --git a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs index 556a3115..307c180a 100644 --- a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs @@ -37,7 +37,7 @@ private static string GetServerName(McpServerOptions serverOptions) { Throw.IfNull(serverOptions); - return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name; + return serverOptions.ServerInfo?.Name ?? McpServerImpl.DefaultImplementation.Name; } // Neither WindowsConsoleStream nor UnixConsoleStream respect CancellationTokens or cancel any I/O on Dispose. diff --git a/src/ModelContextProtocol.Core/TokenProgress.cs b/src/ModelContextProtocol.Core/TokenProgress.cs index f222fbf7..6b7a91e0 100644 --- a/src/ModelContextProtocol.Core/TokenProgress.cs +++ b/src/ModelContextProtocol.Core/TokenProgress.cs @@ -4,13 +4,13 @@ namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications on the supplied endpoint. +/// progress notifications on the supplied session. /// -internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(McpSession session, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); + _ = session.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/src/ModelContextProtocol/IMcpServerBuilder.cs b/src/ModelContextProtocol/IMcpServerBuilder.cs index 5ec37eba..016e9eb3 100644 --- a/src/ModelContextProtocol/IMcpServerBuilder.cs +++ b/src/ModelContextProtocol/IMcpServerBuilder.cs @@ -3,7 +3,7 @@ namespace Microsoft.Extensions.DependencyInjection; /// -/// Provides a builder for configuring instances. +/// Provides a builder for configuring instances. /// /// /// diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index 2d6314ba..faca02bc 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -827,8 +827,8 @@ public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpSer /// and may begin sending log messages at or above the specified level to the client. /// /// - /// Regardless of whether a handler is provided, an should itself handle - /// such notifications by updating its property to return the + /// Regardless of whether a handler is provided, an should itself handle + /// such notifications by updating its property to return the /// most recently set level. /// /// @@ -908,7 +908,7 @@ private static void AddSingleSessionServerDependencies(IServiceCollection servic ITransport serverTransport = services.GetRequiredService(); IOptions options = services.GetRequiredService>(); ILoggerFactory? loggerFactory = services.GetService(); - return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + return McpServer.Create(serverTransport, options.Value, loggerFactory, services); }); } #endregion diff --git a/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs b/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs index b50e4614..80e8216a 100644 --- a/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs +++ b/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs @@ -10,7 +10,7 @@ namespace ModelContextProtocol; /// /// The host's application lifetime. If available, it will have termination requested when the session's run completes. /// -internal sealed class SingleSessionMcpServerHostedService(IMcpServer session, IHostApplicationLifetime? lifetime = null) : BackgroundService +internal sealed class SingleSessionMcpServerHostedService(McpServer session, IHostApplicationLifetime? lifetime = null) : BackgroundService { /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs index efff68c8..09e14ccb 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs @@ -122,7 +122,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() LoggerFactory ); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken @@ -167,7 +167,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() LoggerFactory ); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs index b480934a..db82f673 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs @@ -109,7 +109,7 @@ public async Task CanAuthenticate() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -129,7 +129,7 @@ public async Task CannotAuthenticate_WithoutOAuthConfiguration() Endpoint = new(McpServerUrl), }, HttpClient, LoggerFactory); - var httpEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + var httpEx = await Assert.ThrowsAsync(async () => await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal(HttpStatusCode.Unauthorized, httpEx.StatusCode); @@ -159,7 +159,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() }, HttpClient, LoggerFactory); // The EqualException is thrown by HandleAuthorizationUrlAsync when the /authorize request gets a 400 - var equalEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + var equalEx = await Assert.ThrowsAsync(async () => await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } @@ -190,7 +190,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -219,7 +219,7 @@ public async Task CanAuthenticate_WithTokenRefresh() // The test-refresh-client should get an expired token first, // then automatically refresh it to get a working token - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); Assert.True(_testOAuthServer.HasIssuedRefreshToken); @@ -252,7 +252,7 @@ public async Task CanAuthenticate_WithExtraParams() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(_lastAuthorizationUri?.Query); @@ -286,7 +286,7 @@ public async Task CannotOverrideExistingParameters_WithExtraParams() }, }, HttpClient, LoggerFactory); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + await Assert.ThrowsAsync(() => McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 9b3c91b9..30e0f9df 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -23,7 +23,7 @@ public override void Dispose() protected abstract SseClientTransportOptions ClientTransportOptions { get; } - private Task GetClientAsync(McpClientOptions? options = null) + private Task GetClientAsync(McpClientOptions? options = null) { return _fixture.ConnectMcpClientAsync(options, LoggerFactory); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 4d0d7356..cce23a53 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -23,7 +23,7 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync( + protected async Task ConnectAsync( string? path = null, SseClientTransportOptions? transportOptions = null, McpClientOptions? clientOptions = null) @@ -37,7 +37,7 @@ protected async Task ConnectAsync( TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); - return await McpClientFactory.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); + return await McpClient.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); } [Fact] @@ -204,7 +204,7 @@ public string EchoWithUserName(string message) private class SamplingRegressionTools { [McpServerTool(Name = "sampling-tool")] - public static async Task SamplingToolAsync(IMcpServer server, string prompt, CancellationToken cancellationToken) + public static async Task SamplingToolAsync(McpServer server, string prompt, CancellationToken cancellationToken) { // This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464 // 1. The client calls tool with request ID 2, because it's the first request after the initialize request. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 8191f609..c419ec69 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -21,8 +21,8 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr Name = "In-memory SSE Client", }; - private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) - => McpClientFactory.CreateAsync( + private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) + => McpClient.CreateAsync( new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -257,7 +257,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, b try { var transportTask = transport.RunAsync(cancellationToken: requestAborted); - await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); + await using var server = McpServer.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); try { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 2aa675c8..7a11bebc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -44,9 +44,9 @@ public SseServerIntegrationTestFixture() public HttpClient HttpClient { get; } - public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) + public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) { - return McpClientFactory.CreateAsync( + return McpClient.CreateAsync( new SseClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), options, loggerFactory, diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index b50a43ed..8bc2130d 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -58,8 +58,8 @@ private async Task StartAsync() HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); } - private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) - => McpClientFactory.CreateAsync( + private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) + => McpClient.CreateAsync( new SseClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), clientOptions, LoggerFactory, TestContext.Current.CancellationToken); @@ -194,7 +194,7 @@ public async Task ScopedServices_Resolve_FromRequestScope() } [McpServerTool(Name = "testSamplingErrors")] - public static async Task TestSamplingErrors(IMcpServer server) + public static async Task TestSamplingErrors(McpServer server) { const string expectedSamplingErrorMessage = "Sampling is not supported in stateless mode."; @@ -212,7 +212,7 @@ public static async Task TestSamplingErrors(IMcpServer server) } [McpServerTool(Name = "testRootsErrors")] - public static async Task TestRootsErrors(IMcpServer server) + public static async Task TestRootsErrors(McpServer server) { const string expectedRootsErrorMessage = "Roots are not supported in stateless mode."; @@ -227,7 +227,7 @@ public static async Task TestRootsErrors(IMcpServer server) } [McpServerTool(Name = "testElicitationErrors")] - public static async Task TestElicitationErrors(IMcpServer server) + public static async Task TestElicitationErrors(McpServer server) { const string expectedElicitationErrorMessage = "Elicitation is not supported in stateless mode."; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 7ce3516e..3ca8010a 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -118,7 +118,7 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echoTool = Assert.Single(tools); @@ -138,7 +138,7 @@ public async Task CanCallToolConcurrently() TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echoTool = Assert.Single(tools); @@ -164,7 +164,7 @@ public async Task SendsDeleteRequestOnDispose() TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Dispose should trigger DELETE request await client.DisposeAsync(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index bb184034..7b2be8f9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -252,7 +252,7 @@ public async Task MultipleConcurrentJsonRpcRequests_IsHandled_InParallel() [Fact] public async Task GetRequest_Receives_UnsolicitedNotifications() { - IMcpServer? server = null; + McpServer? server = null; Builder.Services.AddMcpServer() .WithHttpTransport(options => diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 0bc4134f..743c858c 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -51,7 +51,7 @@ private static async Task Main(string[] args) using var loggerFactory = CreateLoggerFactory(); await using var stdioTransport = new StdioServerTransport("TestServer", loggerFactory); - await using IMcpServer server = McpServerFactory.Create(stdioTransport, options, loggerFactory); + await using McpServer server = McpServer.Create(stdioTransport, options, loggerFactory); Log.Logger.Information("Server running..."); @@ -61,7 +61,7 @@ private static async Task Main(string[] args) await server.RunAsync(); } - private static async Task RunBackgroundLoop(IMcpServer server, CancellationToken cancellationToken = default) + private static async Task RunBackgroundLoop(McpServer server, CancellationToken cancellationToken = default) { var loggingLevels = (LoggingLevel[])Enum.GetValues(typeof(LoggingLevel)); var random = new Random(); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs similarity index 90% rename from tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs rename to tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs index 7516a218..15127502 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs @@ -1,26 +1,24 @@ -using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -using Moq; using System.IO.Pipelines; using System.Text.Json; using System.Threading.Channels; namespace ModelContextProtocol.Tests.Client; -public class McpClientFactoryTests +public class McpClientCreationTests { [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("clientTransport", () => McpClientFactory.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync("clientTransport", () => McpClient.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] public async Task CreateAsync_NopTransport_ReturnsClient() { // Act - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); @@ -39,7 +37,7 @@ public async Task Cancellation_ThrowsCancellationException(bool preCanceled) cts.Cancel(); } - Task t = McpClientFactory.CreateAsync( + Task t = McpClient.CreateAsync( new StreamClientTransport(new Pipe().Writer.AsStream(), new Pipe().Reader.AsStream()), cancellationToken: cts.Token); if (!preCanceled) @@ -85,9 +83,9 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) }; var clientTransport = (IClientTransport)Activator.CreateInstance(transportType)!; - IMcpClient? client = null; + McpClient? client = null; - var actionTask = McpClientFactory.CreateAsync(clientTransport, clientOptions, new Mock().Object, CancellationToken.None); + var actionTask = McpClient.CreateAsync(clientTransport, clientOptions, loggerFactory: null, CancellationToken.None); // Act if (clientTransport is FailureTransport) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTest.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTest.cs new file mode 100644 index 00000000..f4e6062d --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTest.cs @@ -0,0 +1,387 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpClientExtensionsTests +{ + [Fact] + public async Task PingAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.PingAsync(TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.PingAsync' instead", ex.Message); + } + + [Fact] + public async Task GetPromptAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + "name", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.GetPromptAsync' instead", ex.Message); + } + + [Fact] + public async Task CallToolAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + "tool", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.CallToolAsync' instead", ex.Message); + } + + [Fact] + public async Task ListResourcesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListResourcesAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListResourcesAsync' instead", ex.Message); + } + + [Fact] + public void EnumerateResourcesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateResourcesAsync' instead", ex.Message); + } + + [Fact] + public async Task SubscribeToResourceAsync_String_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.SubscribeToResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.SubscribeToResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task SubscribeToResourceAsync_Uri_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.SubscribeToResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.SubscribeToResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_String_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.UnsubscribeFromResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.UnsubscribeFromResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_Uri_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.UnsubscribeFromResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.UnsubscribeFromResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task ReadResourceAsync_String_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task ReadResourceAsync_Uri_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task ReadResourceAsync_Template_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + "mcp://resource/{id}", new Dictionary { ["id"] = 1 }, TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task CompleteAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + var reference = new PromptReference { Name = "prompt" }; + + var ex = await Assert.ThrowsAsync(async () => await client.CompleteAsync( + reference, "arg", "val", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.CompleteAsync' instead", ex.Message); + } + + [Fact] + public async Task ListToolsAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListToolsAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListToolsAsync' instead", ex.Message); + } + + [Fact] + public void EnumerateToolsAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateToolsAsync' instead", ex.Message); + } + + [Fact] + public async Task ListPromptsAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListPromptsAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListPromptsAsync' instead", ex.Message); + } + + [Fact] + public void EnumeratePromptsAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumeratePromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumeratePromptsAsync' instead", ex.Message); + } + + [Fact] + public async Task ListResourceTemplatesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListResourceTemplatesAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListResourceTemplatesAsync' instead", ex.Message); + } + + [Fact] + public void EnumerateResourceTemplatesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateResourceTemplatesAsync' instead", ex.Message); + } + + [Fact] + public async Task PingAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(new object(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.PingAsync(TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task GetPromptAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new GetPromptResult { Messages = [new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.GetPromptAsync("name", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("hi", Assert.IsType(result.Messages[0].Content).Text); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task CallToolAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var callResult = new CallToolResult { Content = [new TextContentBlock { Text = "ok" }] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(callResult, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.CallToolAsync("tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("ok", Assert.IsType(result.Content[0]).Text); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SubscribeToResourceAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(new EmptyResult(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.SubscribeToResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(new EmptyResult(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.UnsubscribeFromResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task CompleteAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var completion = new Completion { Values = ["one", "two"] }; + var resultPayload = new CompleteResult { Completion = completion }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.CompleteAsync(new PromptReference { Name = "p" }, "arg", "val", TestContext.Current.CancellationToken); + + Assert.Contains("one", result.Completion.Values); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_String_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_Uri_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync(new Uri("mcp://resource/1"), TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_Template_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync("mcp://resource/{id}", new Dictionary { ["id"] = 1 }, TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs index 48c3c370..2599d748 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs @@ -73,7 +73,7 @@ public static IEnumerable UriTemplate_InputsProduceExpectedOutputs_Mem public async Task UriTemplate_InputsProduceExpectedOutputs( IReadOnlyDictionary variables, string uriTemplate, object expected) { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.ReadResourceAsync(uriTemplate, variables, TestContext.Current.CancellationToken); Assert.NotNull(result); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs similarity index 95% rename from tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs rename to tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index e3d7ce44..779e31e6 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -11,9 +11,9 @@ namespace ModelContextProtocol.Tests.Client; -public class McpClientExtensionsTests : ClientServerTestBase +public class McpClientTests : ClientServerTestBase { - public McpClientExtensionsTests(ITestOutputHelper outputHelper) + public McpClientTests(ITestOutputHelper outputHelper) : base(outputHelper) { } @@ -197,7 +197,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() [Fact] public async Task ListToolsAsync_AllToolsReturned() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(12, tools.Count); @@ -223,7 +223,7 @@ public async Task ListToolsAsync_AllToolsReturned() [Fact] public async Task EnumerateToolsAsync_AllToolsReturned() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) { @@ -242,7 +242,7 @@ public async Task EnumerateToolsAsync_AllToolsReturned() public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); bool hasTools = false; await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) @@ -263,7 +263,7 @@ public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); @@ -273,7 +273,7 @@ public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() public async Task SendRequestAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -282,7 +282,7 @@ public async Task SendRequestAsync_HonorsJsonSerializerOptions() public async Task SendNotificationAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -291,7 +291,7 @@ public async Task SendNotificationAsync_HonorsJsonSerializerOptions() public async Task GetPromptsAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -300,7 +300,7 @@ public async Task GetPromptsAsync_HonorsJsonSerializerOptions() public async Task WithName_ChangesToolName() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); var originalName = tool.Name; @@ -315,7 +315,7 @@ public async Task WithName_ChangesToolName() public async Task WithDescription_ChangesToolDescription() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); var originalDescription = tool?.Description; var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); @@ -344,7 +344,7 @@ public async Task WithProgress_ProgressReported() return 42; }, new() { Name = "ProgressReporter" })); - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); @@ -372,7 +372,7 @@ private sealed class SynchronousProgress(Action callb [Fact] public async Task AsClientLoggerProvider_MessagesSentToClient() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index ebc7171e..6f625866 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -41,8 +41,8 @@ public void Initialize(ILoggerFactory loggerFactory) _loggerFactory = loggerFactory; } - public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => - McpClientFactory.CreateAsync(new StdioClientTransport(clientId switch + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => + McpClient.CreateAsync(new StdioClientTransport(clientId switch { "everything" => EverythingServerTransportOptions, "test_server" => TestServerTransportOptions, diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 3e4361a5..21168841 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -471,7 +471,7 @@ public async Task CallTool_Stdio_MemoryServer() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new StdioClientTransport(stdioOptions), clientOptions, loggerFactory: LoggerFactory, @@ -495,7 +495,7 @@ public async Task CallTool_Stdio_MemoryServer() public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() { // Get the MCP client and tools from it. - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new StdioClientTransport(_fixture.EverythingServerTransportOptions), cancellationToken: TestContext.Current.CancellationToken); var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -527,7 +527,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() var samplingHandler = new OpenAIClient(s_openAIKey).GetChatClient("gpt-4o-mini") .AsIChatClient() .CreateSamplingHandler(); - await using var client = await McpClientFactory.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() + await using var client = await McpClient.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index ec1c8510..7dcebfd7 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -28,11 +28,11 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) ServiceProvider = sc.BuildServiceProvider(validateScopes: true); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - Server = ServiceProvider.GetRequiredService(); + Server = ServiceProvider.GetRequiredService(); _serverTask = Server.RunAsync(_cts.Token); } - protected IMcpServer Server { get; } + protected McpServer Server { get; } protected IServiceProvider ServiceProvider { get; } @@ -62,9 +62,9 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) + protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) { - return await McpClientFactory.CreateAsync( + return await McpClient.CreateAsync( new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index d697b979..061098b5 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -98,7 +98,7 @@ public void Adds_Prompts_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Prompts() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -127,7 +127,7 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -168,7 +168,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(prompts); @@ -182,7 +182,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Prompt_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ThrowsException), @@ -192,7 +192,7 @@ await Assert.ThrowsAsync(async () => await client.GetPromptAsync( [Fact] public async Task Throws_Exception_On_Unknown_Prompt() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "NotRegisteredPrompt", @@ -204,7 +204,7 @@ public async Task Throws_Exception_On_Unknown_Prompt() [Fact] public async Task Throws_Exception_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "returns_chat_messages", @@ -238,7 +238,7 @@ public async Task WithPrompts_TargetInstance_UsesTarget() sc.AddMcpServer().WithPrompts(target); McpServerPrompt prompt = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolPrompt.Name == "returns_string"); - var result = await prompt.GetAsync(new RequestContext(new Mock().Object) + var result = await prompt.GetAsync(new RequestContext(new Mock().Object) { Params = new GetPromptRequestParams { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index e6b177f5..f74eedad 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -126,7 +126,7 @@ public void Adds_Resources_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Resources() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); Assert.NotNull(client.ServerCapabilities.Resources); @@ -145,7 +145,7 @@ public async Task Can_List_And_Call_Registered_Resources() [Fact] public async Task Can_List_And_Call_Registered_ResourceTemplates() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken); Assert.Equal(3, resources.Count); @@ -162,7 +162,7 @@ public async Task Can_List_And_Call_Registered_ResourceTemplates() [Fact] public async Task Can_Be_Notified_Of_Resource_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); Assert.Equal(5, resources.Count); @@ -203,7 +203,7 @@ public async Task Can_Be_Notified_Of_Resource_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(resources); @@ -221,7 +221,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Resource_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( $"resource://mcp/{nameof(SimpleResources.ThrowsException)}", @@ -231,7 +231,7 @@ await Assert.ThrowsAsync(async () => await client.ReadResourceAsyn [Fact] public async Task Throws_Exception_On_Unknown_Resource() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "test:///NotRegisteredResource", @@ -265,7 +265,7 @@ public async Task WithResources_TargetInstance_UsesTarget() sc.AddMcpServer().WithResources(target); McpServerResource resource = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolResource?.Name == "returns_string"); - var result = await resource.ReadAsync(new RequestContext(new Mock().Object) + var result = await resource.ReadAsync(new RequestContext(new Mock().Object) { Params = new() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index dbea036d..8366646b 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -123,7 +123,7 @@ public void Adds_Tools_To_Server() [Fact] public async Task Can_List_Registered_Tools() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -152,10 +152,10 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdoutPipe = new Pipe(); await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - await using var server = McpServerFactory.Create(transport, options, loggerFactory, ServiceProvider); + await using var server = McpServer.Create(transport, options, loggerFactory, ServiceProvider); var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (var client = await McpClientFactory.CreateAsync( + await using (var client = await McpClient.CreateAsync( new StreamClientTransport( serverInput: stdinPipe.Writer.AsStream(), serverOutput: stdoutPipe.Reader.AsStream(), @@ -187,7 +187,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -228,7 +228,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() [Fact] public async Task Can_Call_Registered_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -247,7 +247,7 @@ public async Task Can_Call_Registered_Tool() [Fact] public async Task Can_Call_Registered_Tool_With_Array_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_array", @@ -270,7 +270,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Null_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_null", @@ -284,7 +284,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Json_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_json", @@ -301,7 +301,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Int_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_integer", @@ -316,7 +316,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() [Fact] public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_complex", @@ -333,7 +333,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() [Fact] public async Task Can_Call_Registered_Tool_With_Instance_Method() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); string[][] parts = new string[2][]; for (int i = 0; i < 2; i++) @@ -362,7 +362,7 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() [Fact] public async Task Returns_IsError_Content_When_Tool_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "throw_exception", @@ -377,7 +377,7 @@ public async Task Returns_IsError_Content_When_Tool_Fails() [Fact] public async Task Throws_Exception_On_Unknown_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "NotRegisteredTool", @@ -389,7 +389,7 @@ public async Task Throws_Exception_On_Unknown_Tool() [Fact] public async Task Returns_IsError_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -516,7 +516,7 @@ public async Task WithTools_TargetInstance_UsesTarget() sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options); McpServerTool tool = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolTool.Name == "get_ctor_parameter"); - var result = await tool.InvokeAsync(new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); + var result = await tool.InvokeAsync(new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); } @@ -548,7 +548,7 @@ public IEnumerator GetEnumerator() [Fact] public async Task Recognizes_Parameter_Types() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -623,7 +623,7 @@ public void Create_ExtractsToolAnnotations_SomeSet() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -639,7 +639,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task HandlesIProgressParameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -693,7 +693,7 @@ public async Task HandlesIProgressParameter() [Fact] public async Task CancellationNotificationsPropagateToToolTokens() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index b940c1c7..5ddc3c54 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -22,7 +22,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task InjectScopedServiceAsArgument() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); var tool = tools.First(t => t.Name == "echo_complex"); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 116c62a1..5ad30d28 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -128,7 +128,7 @@ await RunConnected(async (client, server) => Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } - private static async Task RunConnected(Func action, List clientToServerLog) + private static async Task RunConnected(Func action, List clientToServerLog) { Pipe clientToServerPipe = new(), serverToClientPipe = new(); StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); @@ -137,7 +137,7 @@ private static async Task RunConnected(Func action Task serverTask; - await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() + await using (McpServer server = McpServer.Create(serverTransport, new() { Capabilities = new() { @@ -153,7 +153,7 @@ private static async Task RunConnected(Func action { serverTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (IMcpClient client = await McpClientFactory.CreateAsync( + await using (McpClient client = await McpClient.CreateAsync( clientTransport, cancellationToken: TestContext.Current.CancellationToken)) { diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index ffd95076..e3faf05f 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -43,7 +43,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() }; // Create client and run tests - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new SseClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, @@ -90,7 +90,7 @@ public async Task Sampling_Sse_EverythingServer() }, }; - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new SseClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, diff --git a/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs b/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs new file mode 100644 index 00000000..613c703c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs @@ -0,0 +1,118 @@ +using ModelContextProtocol.Protocol; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpEndpointExtensionsTests +{ + [Fact] + public async Task SendRequestAsync_Generic_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendRequestAsync( + endpoint, "method", "param", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendRequestAsync' instead", ex.Message); + } + + [Fact] + public async Task SendNotificationAsync_Parameterless_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendNotificationAsync( + endpoint, "notify", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendNotificationAsync' instead", ex.Message); + } + + [Fact] + public async Task SendNotificationAsync_Generic_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendNotificationAsync( + endpoint, "notify", "payload", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendNotificationAsync' instead", ex.Message); + } + + [Fact] + public async Task NotifyProgressAsync_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.NotifyProgressAsync( + endpoint, new ProgressToken("t1"), new ProgressNotificationValue { Progress = 0.5f }, cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.NotifyProgressAsync' instead", ex.Message); + } + + [Fact] + public async Task SendRequestAsync_Generic_Forwards_To_McpSession_SendRequestAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(42, McpJsonUtilities.DefaultOptions), + }); + + IMcpEndpoint endpoint = mockSession.Object; + + var result = await endpoint.SendRequestAsync("method", "param", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(42, result); + mockSession.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SendNotificationAsync_Parameterless_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.SendNotificationAsync("notify", cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SendNotificationAsync_Generic_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.SendNotificationAsync("notify", "payload", cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task NotifyProgressAsync_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.NotifyProgressAsync(new ProgressToken("progress-token"), new ProgressNotificationValue { Progress = 1 }, cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs index f4474391..22fd69c1 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs @@ -67,7 +67,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task Can_Elicit_Information() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 0d18667e..25470650 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -13,7 +13,7 @@ public NotificationHandlerTests(ITestOutputHelper testOutputHelper) public async Task RegistrationsAreRemovedWhenDisposed() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int Iterations = 10; @@ -40,7 +40,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() public async Task MultipleRegistrationsResultInMultipleCallbacks() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -80,7 +80,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() public async Task MultipleHandlersRunEvenIfOneThrows() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -122,7 +122,7 @@ public async Task MultipleHandlersRunEvenIfOneThrows() public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -163,7 +163,7 @@ public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int nu public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs new file mode 100644 index 00000000..5569f993 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs @@ -0,0 +1,195 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Server; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpServerExtensionsTests +{ + [Fact] + public async Task SampleAsync_Request_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( + new CreateMessageRequestParams { Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }, + TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SampleAsync' instead", ex.Message); + } + + [Fact] + public async Task SampleAsync_Messages_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( + [new ChatMessage(ChatRole.User, "hi")], cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SampleAsync' instead", ex.Message); + } + + [Fact] + public void AsSamplingChatClient_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(server.AsSamplingChatClient); + Assert.Contains("Prefer using 'McpServer.AsSamplingChatClient' instead", ex.Message); + } + + [Fact] + public void AsClientLoggerProvider_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(server.AsClientLoggerProvider); + Assert.Contains("Prefer using 'McpServer.AsClientLoggerProvider' instead", ex.Message); + } + + [Fact] + public async Task RequestRootsAsync_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.RequestRootsAsync( + new ListRootsRequestParams(), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.RequestRootsAsync' instead", ex.Message); + } + + [Fact] + public async Task ElicitAsync_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.ElicitAsync( + new ElicitRequestParams { Message = "hello" }, TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.ElicitAsync' instead", ex.Message); + } + + [Fact] + public async Task SampleAsync_Request_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new CreateMessageResult + { + Content = new TextContentBlock { Text = "resp" }, + Model = "test-model", + Role = Role.Assistant, + StopReason = "endTurn", + }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Sampling = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] + }, TestContext.Current.CancellationToken); + + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("resp", Assert.IsType(result.Content).Text); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SampleAsync_Messages_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new CreateMessageResult + { + Content = new TextContentBlock { Text = "resp" }, + Model = "test-model", + Role = Role.Assistant, + StopReason = "endTurn", + }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Sampling = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var chatResponse = await server.SampleAsync([new ChatMessage(ChatRole.User, "hi")], cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("test-model", chatResponse.ModelId); + var last = chatResponse.Messages.Last(); + Assert.Equal(ChatRole.Assistant, last.Role); + Assert.Equal("resp", last.Text); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task RequestRootsAsync_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new ListRootsResult { Roots = [new Root { Uri = "root://a" }] }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Roots = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.RequestRootsAsync(new ListRootsRequestParams(), TestContext.Current.CancellationToken); + + Assert.Equal("root://a", result.Roots[0].Uri); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ElicitAsync_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new ElicitResult { Action = "accept" }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Elicitation = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.ElicitAsync(new ElicitRequestParams { Message = "hi" }, TestContext.Current.CancellationToken); + + Assert.Equal("accept", result.Action); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs deleted file mode 100644 index 034a30bd..00000000 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ /dev/null @@ -1,45 +0,0 @@ -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; - -namespace ModelContextProtocol.Tests.Server; - -public class McpServerFactoryTests : LoggedTest -{ - private readonly McpServerOptions _options; - - public McpServerFactoryTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - _options = new McpServerOptions - { - ProtocolVersion = "1.0", - InitializationTimeout = TimeSpan.FromSeconds(30) - }; - } - - [Fact] - public async Task Create_Should_Initialize_With_Valid_Parameters() - { - // Arrange & Act - await using var transport = new TestServerTransport(); - await using IMcpServer server = McpServerFactory.Create(transport, _options, LoggerFactory); - - // Assert - Assert.NotNull(server); - } - - [Fact] - public void Create_Throws_For_Null_ServerTransport() - { - // Arrange, Act & Assert - Assert.Throws("transport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); - } - - [Fact] - public async Task Create_Throws_For_Null_Options() - { - // Arrange, Act & Assert - await using var transport = new TestServerTransport(); - Assert.Throws("serverOptions", () => McpServerFactory.Create(transport, null!, LoggerFactory)); - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index b2e74873..be271a68 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -25,7 +25,7 @@ public void CanCreateServerWithLoggingLevelHandler() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact] @@ -39,7 +39,7 @@ public void AddingLoggingLevelHandlerSetsLoggingCapability() var provider = services.BuildServiceProvider(); - var server = provider.GetRequiredService(); + var server = provider.GetRequiredService(); Assert.NotNull(server.ServerOptions.Capabilities?.Logging); Assert.NotNull(server.ServerOptions.Capabilities.Logging.SetLoggingLevelHandler); @@ -52,7 +52,7 @@ public void ServerWithoutCallingLoggingLevelHandlerDoesNotSetLoggingCapability() services.AddMcpServer() .WithStdioServerTransport(); var provider = services.BuildServiceProvider(); - var server = provider.GetRequiredService(); + var server = provider.GetRequiredService(); Assert.Null(server.ServerOptions.Capabilities?.Logging); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 39e9b72f..d49aff5b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -35,9 +35,9 @@ public void Create_InvalidArgs_Throws() [Fact] public async Task SupportsIMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerPrompt prompt = McpServerPrompt.Create((IMcpServer server) => + McpServerPrompt prompt = McpServerPrompt.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new ChatMessage(ChatRole.User, "Hello"); @@ -63,7 +63,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestPrompt)); @@ -86,11 +86,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -125,11 +125,11 @@ public async Task SupportsServiceFromDI() Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); await Assert.ThrowsAnyAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -150,7 +150,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -163,7 +163,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("disposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -176,7 +176,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() _ => new AsyncDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -189,7 +189,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable _ => new AsyncDisposableAndDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("disposals:0, asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -205,7 +205,7 @@ public async Task CanReturnGetPromptResult() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Same(expected, actual); @@ -222,7 +222,7 @@ public async Task CanReturnText() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -248,7 +248,7 @@ public async Task CanReturnPromptMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -280,7 +280,7 @@ public async Task CanReturnPromptMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -307,7 +307,7 @@ public async Task CanReturnChatMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -339,7 +339,7 @@ public async Task CanReturnChatMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -360,7 +360,7 @@ public async Task ThrowsForNullReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); } @@ -373,7 +373,7 @@ public async Task ThrowsForUnexpectedTypeReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 011c4f2b..9e688455 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -50,7 +50,7 @@ public void CanCreateServerWithResource() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } @@ -86,7 +86,7 @@ public void CanCreateServerWithResourceTemplates() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact] @@ -109,7 +109,7 @@ public void CreatingReadHandlerWithNoListHandlerSucceeds() }); var sp = services.BuildServiceProvider(); - sp.GetRequiredService(); + sp.GetRequiredService(); } [Fact] @@ -133,7 +133,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() McpServerResource t; ReadResourceResult? result; - IMcpServer server = new Mock().Object; + McpServer server = new Mock().Object; t = McpServerResource.Create(() => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); @@ -143,7 +143,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); - t = McpServerResource.Create((IMcpServer server) => "42", new() { Name = Name }); + t = McpServerResource.Create((McpServer server) => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, @@ -277,7 +277,7 @@ public async Task UriTemplate_NonMatchingUri_ReturnsNull(string uri) McpServerResource t = McpServerResource.Create((string arg1) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); Assert.Null(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -288,7 +288,7 @@ public async Task UriTemplate_IsHostCaseInsensitive(string actualUri, string que { McpServerResource t = McpServerResource.Create(() => "resource", new() { UriTemplate = actualUri }); Assert.NotNull(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = queriedUri } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = queriedUri } }, TestContext.Current.CancellationToken)); } @@ -317,7 +317,7 @@ public async Task UriTemplate_MissingParameter_Throws(string uri) McpServerResource t = McpServerResource.Create((string arg1, int arg2) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); await Assert.ThrowsAsync(async () => await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -330,25 +330,25 @@ public async Task UriTemplate_MissingOptionalParameter_Succeeds() ReadResourceResult? result; result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first42", ((TextResourceContents)result.Contents[0]).Text); @@ -357,9 +357,9 @@ public async Task UriTemplate_MissingOptionalParameter_Succeeds() [Fact] public async Task SupportsIMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -381,7 +381,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestResource)); @@ -404,11 +404,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -467,7 +467,7 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime McpServerResource resource = services.GetRequiredService(); - Mock mockServer = new(); + Mock mockServer = new(); await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, @@ -496,7 +496,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services, Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -512,7 +512,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposableResourceType()); var result = await resource1.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("0", ((TextResourceContents)result.Contents[0]).Text); @@ -523,8 +523,8 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() [Fact] public async Task CanReturnReadResult() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new ReadResourceResult { Contents = new List { new TextResourceContents { Text = "hello" } } }; @@ -540,8 +540,8 @@ public async Task CanReturnReadResult() [Fact] public async Task CanReturnResourceContents() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new TextResourceContents { Text = "hello" }; @@ -557,8 +557,8 @@ public async Task CanReturnResourceContents() [Fact] public async Task CanReturnCollectionOfResourceContents() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (IList) @@ -579,8 +579,8 @@ public async Task CanReturnCollectionOfResourceContents() [Fact] public async Task CanReturnString() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -596,8 +596,8 @@ public async Task CanReturnString() [Fact] public async Task CanReturnCollectionOfStrings() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { "42", "43" }; @@ -614,8 +614,8 @@ public async Task CanReturnCollectionOfStrings() [Fact] public async Task CanReturnDataContent() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new DataContent(new byte[] { 0, 1, 2 }, "application/octet-stream"); @@ -632,8 +632,8 @@ public async Task CanReturnDataContent() [Fact] public async Task CanReturnCollectionOfAIContent() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 6750b2ca..61cda701 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -32,12 +32,38 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = }; } + [Fact] + public async Task Create_Should_Initialize_With_Valid_Parameters() + { + // Arrange & Act + await using var transport = new TestServerTransport(); + await using McpServer server = McpServer.Create(transport, _options, LoggerFactory); + + // Assert + Assert.NotNull(server); + } + + [Fact] + public void Create_Throws_For_Null_ServerTransport() + { + // Arrange, Act & Assert + Assert.Throws("transport", () => McpServer.Create(null!, _options, LoggerFactory)); + } + + [Fact] + public async Task Create_Throws_For_Null_Options() + { + // Arrange, Act & Assert + await using var transport = new TestServerTransport(); + Assert.Throws("serverOptions", () => McpServer.Create(transport, null!, LoggerFactory)); + } + [Fact] public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -47,7 +73,7 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory)); + Assert.Throws(() => McpServer.Create(null!, _options, LoggerFactory)); } [Fact] @@ -55,7 +81,7 @@ public async Task Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert await using var transport = new TestServerTransport(); - Assert.Throws(() => McpServerFactory.Create(transport, null!, LoggerFactory)); + Assert.Throws(() => McpServer.Create(transport, null!, LoggerFactory)); } [Fact] @@ -63,7 +89,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, null); + await using var server = McpServer.Create(transport, _options, null); // Assert Assert.NotNull(server); @@ -74,7 +100,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, null); + await using var server = McpServer.Create(transport, _options, LoggerFactory, null); // Assert Assert.NotNull(server); @@ -85,7 +111,7 @@ public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Run { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act & Assert @@ -100,7 +126,7 @@ public async Task SampleAsync_Should_Throw_Exception_If_Client_Does_Not_Support_ { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); var action = async () => await server.SampleAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -114,7 +140,7 @@ public async Task SampleAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -136,7 +162,7 @@ public async Task RequestRootsAsync_Should_Throw_Exception_If_Client_Does_Not_Su { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -148,7 +174,7 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -170,7 +196,7 @@ public async Task ElicitAsync_Should_Throw_Exception_If_Client_Does_Not_Support_ { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -182,7 +208,7 @@ public async Task ElicitAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Elicitation = new ElicitationCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -216,7 +242,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Initialize_Requests() { - AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(IMcpServer).Assembly).GetName(); + AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(McpServer).Assembly).GetName(); await Can_Handle_Requests( serverCapabilities: null, method: RequestMethods.Initialize, @@ -510,7 +536,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = McpServerFactory.Create(transport, options, LoggerFactory); + await using var server = McpServer.Create(transport, options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -544,7 +570,7 @@ private async Task Succeeds_Even_If_No_Handler_Assigned(ServerCapabilities serve await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - var server = McpServerFactory.Create(transport, options, LoggerFactory); + var server = McpServer.Create(transport, options, LoggerFactory); await server.DisposeAsync(); } @@ -589,7 +615,7 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() public async Task Can_SendMessage_Before_RunAsync() { await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); var logNotification = new JsonRpcNotification { @@ -605,22 +631,22 @@ public async Task Can_SendMessage_Before_RunAsync() Assert.Same(logNotification, transport.SentMessages[0]); } - private static void SetClientCapabilities(IMcpServer server, ClientCapabilities capabilities) + private static void SetClientCapabilities(McpServer server, ClientCapabilities capabilities) { - PropertyInfo? property = server.GetType().GetProperty("ClientCapabilities", BindingFlags.Public | BindingFlags.Instance); - Assert.NotNull(property); - property.SetValue(server, capabilities); + FieldInfo? field = server.GetType().GetField("_clientCapabilities", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(field); + field.SetValue(server, capabilities); } - private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServer + private sealed class TestServerForIChatClient(bool supportsSampling) : McpServer { - public ClientCapabilities? ClientCapabilities => + public override ClientCapabilities? ClientCapabilities => supportsSampling ? new ClientCapabilities { Sampling = new SamplingCapability() } : null; - public McpServerOptions ServerOptions => new(); + public override McpServerOptions ServerOptions => new(); - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { CreateMessageRequestParams? rp = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions); @@ -653,17 +679,17 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati }); } - public ValueTask DisposeAsync() => default; + public override ValueTask DisposeAsync() => default; - public string? SessionId => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); - public IServiceProvider? Services => throw new NotImplementedException(); - public LoggingLevel? LoggingLevel => throw new NotImplementedException(); - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => + public override string? SessionId => throw new NotImplementedException(); + public override Implementation? ClientInfo => throw new NotImplementedException(); + public override IServiceProvider? Services => throw new NotImplementedException(); + public override LoggingLevel? LoggingLevel => throw new NotImplementedException(); + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task RunAsync(CancellationToken cancellationToken = default) => + public override Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => throw new NotImplementedException(); } @@ -683,7 +709,7 @@ public async Task NotifyProgress_Should_Be_Handled() })], }; - var server = McpServerFactory.Create(transport, options, LoggerFactory); + var server = McpServer.Create(transport, options, LoggerFactory); Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index f961eef3..c9cee114 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -42,9 +42,9 @@ public void Create_InvalidArgs_Throws() [Fact] public async Task SupportsIMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -67,7 +67,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestTool)); @@ -90,11 +90,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -154,7 +154,7 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Assert.DoesNotContain("actualMyService", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); - Mock mockServer = new(); + Mock mockServer = new(); var result = await tool.InvokeAsync( new RequestContext(mockServer.Object), @@ -183,7 +183,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -198,7 +198,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("""{"disposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -213,7 +213,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -232,7 +232,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -241,8 +241,8 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable [Fact] public async Task CanReturnCollectionOfAIContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { @@ -273,8 +273,8 @@ public async Task CanReturnCollectionOfAIContent() [InlineData("data:audio/wav;base64,1234", "audio")] public async Task CanReturnSingleAIContent(string data, string type) { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return type switch @@ -316,8 +316,8 @@ public async Task CanReturnSingleAIContent(string data, string type) [Fact] public async Task CanReturnNullAIContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (string?)null; @@ -331,8 +331,8 @@ public async Task CanReturnNullAIContent() [Fact] public async Task CanReturnString() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -347,8 +347,8 @@ public async Task CanReturnString() [Fact] public async Task CanReturnCollectionOfStrings() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { "42", "43" }; @@ -363,8 +363,8 @@ public async Task CanReturnCollectionOfStrings() [Fact] public async Task CanReturnMcpContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new TextContentBlock { Text = "42" }; @@ -380,8 +380,8 @@ public async Task CanReturnMcpContent() [Fact] public async Task CanReturnCollectionOfMcpContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (IList) @@ -407,8 +407,8 @@ public async Task CanReturnCallToolResult() Content = new List { new TextContentBlock { Text = "text" }, new ImageContentBlock { Data = "1234", MimeType = "image/png" } } }; - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return response; @@ -465,7 +465,7 @@ public async Task ToolCallError_LogsErrorMessage() throw new InvalidOperationException(exceptionMessage); }, new() { Name = toolName, Services = serviceProvider }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object) { Params = new CallToolRequestParams { Name = toolName }, @@ -492,7 +492,7 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) { JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { Name = "tool", UseStructuredContent = true, SerializerOptions = options }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object) { Params = new CallToolRequestParams { Name = "tool" }, @@ -510,7 +510,7 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSchema() { McpServerTool tool = McpServerTool.Create(() => { }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object) { Params = new CallToolRequestParams { Name = "tool" }, @@ -550,7 +550,7 @@ public async Task StructuredOutput_Disabled_ReturnsExpectedSchema(T value) { JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { UseStructuredContent = false, SerializerOptions = options }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object) { Params = new CallToolRequestParams { Name = "tool" }, diff --git a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs index f3927be6..d14c376c 100644 --- a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs @@ -35,7 +35,7 @@ public async Task SigInt_DisposesTestServerWithHosting_Gracefully() process.StandardInput.BaseStream, serverName: "TestServerWithHosting"); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new TestClientTransport(streamServerTransport), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index dfde342a..48c2b953 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -8,7 +8,7 @@ namespace ModelContextProtocol.Tests.Transport; public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime; - + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() { @@ -18,13 +18,13 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }, LoggerFactory) : new(new() { Command = "ls", Arguments = [id] }, LoggerFactory); - IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + IOException e = await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { Assert.Contains(id, e.ToString()); } } - + [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { @@ -46,7 +46,7 @@ public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"], StandardErrorLines = stdErrCallback }, LoggerFactory) : new(new() { Command = "ls", Arguments = [id], StandardErrorLines = stdErrCallback }, LoggerFactory); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.InRange(count, 1, int.MaxValue); Assert.Contains(id, sb.ToString());