From fccff045d0191690749bc7f17168871adf8ed3a6 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 10 Apr 2025 22:59:58 -0400 Subject: [PATCH 1/3] Wrap each request in a service scope --- .../Server/AIFunctionMcpServerPrompt.cs | 6 +-- .../Server/AIFunctionMcpServerTool.cs | 6 +-- src/ModelContextProtocol/Server/McpServer.cs | 51 +++++++++++++++---- .../Server/McpServerOptions.cs | 9 ++++ .../Server/RequestContext.cs | 44 +++++++++++++++- .../ClientServerTestBase.cs | 2 +- .../McpServerBuilderExtensionsToolsTests.cs | 18 ------- .../Configuration/McpServerScopedTests.cs | 40 +++++++++++++++ .../ModelContextProtocol.Tests/JsonContext.cs | 20 ++++++++ .../Server/McpServerPromptTests.cs | 34 ++++++------- .../Server/McpServerToolTests.cs | 32 ++++++------ 11 files changed, 188 insertions(+), 74 deletions(-) create mode 100644 tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs create mode 100644 tests/ModelContextProtocol.Tests/JsonContext.cs diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs index 7684d6253..fbdf49520 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs @@ -99,7 +99,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - GetRequestContext(args)?.Server?.Services?.GetService(pi.ParameterType) ?? + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -111,7 +111,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Server?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -214,7 +214,7 @@ public override async Task GetAsync( AIFunctionArguments arguments = new() { - Services = request.Server?.Services, + Services = request.Services, Context = new Dictionary() { [typeof(RequestContext)] = request } }; diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index abb5a197b..6fbf68376 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -120,7 +120,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - GetRequestContext(args)?.Server?.Services?.GetService(pi.ParameterType) ?? + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -132,7 +132,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Server?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -245,7 +245,7 @@ public override async Task InvokeAsync( AIFunctionArguments arguments = new() { - Services = request.Server?.Services, + Services = request.Services, Context = new Dictionary() { [typeof(RequestContext)] = request } }; diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index a625e4c2e..5696efd8a 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -1,3 +1,4 @@ +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; @@ -19,6 +20,7 @@ internal sealed class McpServer : McpEndpoint, IMcpServer }; private readonly ITransport _sessionTransport; + private readonly bool _servicesScopePerRequest; private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; @@ -54,6 +56,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? ServerOptions = options; Services = serviceProvider; _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _servicesScopePerRequest = options.ScopeRequests; // Configure all request handlers based on the supplied options. SetInitializeHandler(options); @@ -196,7 +199,7 @@ private void SetCompletionHandler(McpServerOptions options) // This capability is not optional, so return an empty result if there is no handler. RequestHandlers.Set( RequestMethods.CompletionComplete, - (request, cancellationToken) => completeHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(completeHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.CompleteRequestParams, McpJsonUtilities.JsonContext.Default.CompleteResult); } @@ -221,20 +224,20 @@ private void SetResourcesHandler(McpServerOptions options) RequestHandlers.Set( RequestMethods.ResourcesList, - (request, cancellationToken) => listResourcesHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(listResourcesHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult); RequestHandlers.Set( RequestMethods.ResourcesRead, - (request, cancellationToken) => readResourceHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(readResourceHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult); listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); RequestHandlers.Set( RequestMethods.ResourcesTemplatesList, - (request, cancellationToken) => listResourceTemplatesHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(listResourceTemplatesHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); @@ -252,13 +255,13 @@ private void SetResourcesHandler(McpServerOptions options) RequestHandlers.Set( RequestMethods.ResourcesSubscribe, - (request, cancellationToken) => subscribeHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(subscribeHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); RequestHandlers.Set( RequestMethods.ResourcesUnsubscribe, - (request, cancellationToken) => unsubscribeHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(unsubscribeHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); } @@ -345,13 +348,13 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals RequestHandlers.Set( RequestMethods.PromptsList, - (request, cancellationToken) => listPromptsHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(listPromptsHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult); RequestHandlers.Set( RequestMethods.PromptsGet, - (request, cancellationToken) => getPromptHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(getPromptHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, McpJsonUtilities.JsonContext.Default.GetPromptResult); } @@ -438,13 +441,13 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) RequestHandlers.Set( RequestMethods.ToolsList, - (request, cancellationToken) => listToolsHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(listToolsHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult); RequestHandlers.Set( RequestMethods.ToolsCall, - (request, cancellationToken) => callToolHandler(new(this, request), cancellationToken), + (request, cancellationToken) => InvokeHandlerAsync(callToolHandler, request, cancellationToken), McpJsonUtilities.JsonContext.Default.CallToolRequestParams, McpJsonUtilities.JsonContext.Default.CallToolResponse); } @@ -473,7 +476,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) // If a handler was provided, now delegate to it. if (setLoggingLevelHandler is not null) { - return setLoggingLevelHandler(new(this, request), cancellationToken); + return InvokeHandlerAsync(setLoggingLevelHandler, request, cancellationToken); } // Otherwise, consider it handled. @@ -483,6 +486,32 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) McpJsonUtilities.JsonContext.Default.EmptyResult); } + private Task InvokeHandlerAsync( + Func, CancellationToken, Task> handler, + TParams? args, + CancellationToken cancellationToken) + { + return _servicesScopePerRequest ? + InvokeScopedAsync(handler, args, cancellationToken) : + handler(new(this) { Params = args }, cancellationToken); + + async Task InvokeScopedAsync( + Func, CancellationToken, Task> handler, + TParams? args, + CancellationToken cancellationToken) + { + using var scope = Services?.GetService()?.CreateScope(); + + return await handler( + new RequestContext(this) + { + Services = scope?.ServiceProvider ?? Services, + Params = args + }, + cancellationToken).ConfigureAwait(false); + } + } + /// Maps a to a . internal static LoggingLevel ToLoggingLevel(LogLevel level) => level switch diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index d3b9f0758..6880d2f2b 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -56,4 +56,13 @@ public class McpServerOptions /// to provide context about available functionality. /// public string? ServerInstructions { get; set; } + + /// + /// Gets or sets whether to create a new service provider scope for each handled request. + /// + /// + /// The default is . When , each invocation of a request + /// handler will be invoked within a new service scope. + /// + public bool ScopeRequests { get; set; } = true; } diff --git a/src/ModelContextProtocol/Server/RequestContext.cs b/src/ModelContextProtocol/Server/RequestContext.cs index 9e3c946cb..88f2d13bf 100644 --- a/src/ModelContextProtocol/Server/RequestContext.cs +++ b/src/ModelContextProtocol/Server/RequestContext.cs @@ -1,9 +1,10 @@ using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Utils; namespace ModelContextProtocol.Server; /// -/// Provides a context container that provides access to both the server instance and the client request parameters. +/// Provides a context container that provides access to the client request parameters and resources for the request. /// /// Type of the request parameters specific to each MCP operation. /// @@ -11,4 +12,43 @@ namespace ModelContextProtocol.Server; /// This type is typically received as a parameter in handler delegates registered with , /// and may be injected as parameters into s. /// -public record RequestContext(IMcpServer Server, TParams? Params); \ No newline at end of file +public sealed class RequestContext +{ + /// The server with which this instance is associated. + private IMcpServer _server; + + /// + /// Initializes a new instance of the class with the specified server. + /// + /// The server with which this instance is associated. + public RequestContext(IMcpServer server) + { + Throw.IfNull(server); + + _server = server; + Services = server.Services; + } + + /// Gets or sets the server with which this instance is associated. + public IMcpServer Server + { + get => _server; + set + { + Throw.IfNull(value); + _server = value; + } + } + + /// Gets or sets the services associated with this request. + /// + /// This may not be the same instance stored in + /// if was true, in which case this + /// might be a scoped derived from the server's + /// . + /// + public IServiceProvider? Services { get; set; } + + /// Gets or sets the parameters associated with this request. + public TParams? Params { get; set; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index 187d05ea5..75797bb8a 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -25,7 +25,7 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) .AddMcpServer() .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); ConfigureServices(sc, _builder); - ServiceProvider = sc.BuildServiceProvider(); + ServiceProvider = sc.BuildServiceProvider(validateScopes: true); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); Server = ServiceProvider.GetRequiredService(); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index bee1ebc21..3e51a8f4a 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -757,26 +757,8 @@ internal class ToolTypeWithNoAttribute public static string MethodD(string d) => d.ToString(); } - public class ComplexObject - { - public string? Name { get; set; } - public int Age { get; set; } - } - public class ObjectWithId { public string Id { get; set; } = Guid.NewGuid().ToString("N"); } - - [JsonSerializable(typeof(bool))] - [JsonSerializable(typeof(int))] - [JsonSerializable(typeof(long))] - [JsonSerializable(typeof(double))] - [JsonSerializable(typeof(string))] - [JsonSerializable(typeof(DateTime))] - [JsonSerializable(typeof(DateTimeOffset))] - [JsonSerializable(typeof(ComplexObject))] - [JsonSerializable(typeof(string[]))] - [JsonSerializable(typeof(JsonElement))] - partial class JsonContext : JsonSerializerContext; } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs new file mode 100644 index 000000000..93573693e --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -0,0 +1,40 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Server; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Configuration; + +public partial class McpServerScopedTests : ClientServerTestBase +{ + public McpServerScopedTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithTools(serializerOptions: JsonContext.Default.Options); + services.AddScoped(_ => new ComplexObject() { Name = "Scoped" }); + } + + [Fact] + public async Task InjectScopedServiceAsArgument() + { + IMcpClient client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(JsonContext.Default.Options, TestContext.Current.CancellationToken); + var tool = tools.First(t => t.Name == nameof(EchoTool.EchoComplex)); + Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.JsonSchema, AIJsonUtilities.DefaultOptions)); + + Assert.Contains("\"Scoped\"", JsonSerializer.Serialize(await tool.InvokeAsync(cancellationToken: TestContext.Current.CancellationToken), AIJsonUtilities.DefaultOptions)); + } + + [McpServerToolType] + public sealed class EchoTool() + { + [McpServerTool] + public static string EchoComplex(ComplexObject complex) => complex.Name!; + } +} diff --git a/tests/ModelContextProtocol.Tests/JsonContext.cs b/tests/ModelContextProtocol.Tests/JsonContext.cs new file mode 100644 index 000000000..4714380bf --- /dev/null +++ b/tests/ModelContextProtocol.Tests/JsonContext.cs @@ -0,0 +1,20 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +public class ComplexObject +{ + public string? Name { get; set; } + public int Age { get; set; } +} + +[JsonSerializable(typeof(bool))] +[JsonSerializable(typeof(int))] +[JsonSerializable(typeof(long))] +[JsonSerializable(typeof(double))] +[JsonSerializable(typeof(string))] +[JsonSerializable(typeof(DateTime))] +[JsonSerializable(typeof(DateTimeOffset))] +[JsonSerializable(typeof(ComplexObject))] +[JsonSerializable(typeof(string[]))] +[JsonSerializable(typeof(JsonElement))] +partial class JsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index cca61721d..e3ffd9d90 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -33,7 +33,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -59,16 +59,12 @@ public async Task SupportsServiceFromDI() Assert.Contains("something", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); - Mock mockServer = new(); - await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); - mockServer.SetupGet(x => x.Services).Returns(services); - var result = await prompt.GetAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(new Mock().Object) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("Hello", result.Messages[0].Content.Text); } @@ -89,7 +85,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("Hello", result.Messages[0].Content.Text); } @@ -102,7 +98,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() typeof(DisposablePromptType)); var result = await prompt1.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("disposals:1", result.Messages[0].Content.Text); } @@ -115,7 +111,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() typeof(AsyncDisposablePromptType)); var result = await prompt1.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("asyncDisposals:1", result.Messages[0].Content.Text); } @@ -128,7 +124,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable typeof(AsyncDisposableAndDisposablePromptType)); var result = await prompt1.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("disposals:0, asyncDisposals:1", result.Messages[0].Content.Text); } @@ -144,7 +140,7 @@ public async Task CanReturnGetPromptResult() }); var actual = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Same(expected, actual); @@ -161,7 +157,7 @@ public async Task CanReturnText() }); var actual = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -187,7 +183,7 @@ public async Task CanReturnPromptMessage() }); var actual = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -218,7 +214,7 @@ public async Task CanReturnPromptMessages() }); var actual = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -247,7 +243,7 @@ public async Task CanReturnChatMessage() }); var actual = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -280,7 +276,7 @@ public async Task CanReturnChatMessages() }); var actual = await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -303,7 +299,7 @@ public async Task ThrowsForNullReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); } @@ -316,7 +312,7 @@ public async Task ThrowsForUnexpectedTypeReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 3d9f777ee..4c42ba242 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -40,7 +40,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Equal("42", result.Content[0].Text); } @@ -92,14 +92,12 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.True(result.IsError); - mockServer.SetupGet(x => x.Services).Returns(services); - result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("42", result.Content[0].Text); } @@ -120,7 +118,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("42", result.Content[0].Text); } @@ -135,7 +133,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("""{"disposals":1}""", result.Content[0].Text); } @@ -150,7 +148,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1}""", result.Content[0].Text); } @@ -165,7 +163,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(null!, null), + new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", result.Content[0].Text); } @@ -186,7 +184,7 @@ public async Task CanReturnCollectionOfAIContent() }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Equal(3, result.Content.Count); @@ -223,7 +221,7 @@ public async Task CanReturnSingleAIContent(string data, string type) }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Single(result.Content); @@ -251,7 +249,7 @@ public async Task CanReturnNullAIContent() return (string?)null; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Empty(result.Content); } @@ -266,7 +264,7 @@ public async Task CanReturnString() return "42"; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", result.Content[0].Text); @@ -283,7 +281,7 @@ public async Task CanReturnCollectionOfStrings() return new List() { "42", "43" }; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Equal(2, result.Content.Count); Assert.Equal("42", result.Content[0].Text); @@ -302,7 +300,7 @@ public async Task CanReturnMcpContent() return new Content { Text = "42", Type = "text" }; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", result.Content[0].Text); @@ -319,7 +317,7 @@ public async Task CanReturnCollectionOfMcpContent() return new List() { new() { Text = "42", Type = "text" }, new() { Data = "1234", Type = "image", MimeType = "image/png" } }; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Equal(2, result.Content.Count); Assert.Equal("42", result.Content[0].Text); @@ -345,7 +343,7 @@ public async Task CanReturnCallToolResponse() return response; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object, null), + new RequestContext(mockServer.Object), TestContext.Current.CancellationToken); Assert.Same(response, result); From da912e96fdba4ad8f8802f116753f013208cc2db Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 10 Apr 2025 23:51:34 -0400 Subject: [PATCH 2/3] Change request/notification delegates to return ValueTask --- .editorconfig | 3 + samples/EverythingServer/Program.cs | 53 +++++++-------- .../Client/McpClientExtensions.cs | 2 +- .../McpServerBuilderExtensions.cs | 22 +++---- src/ModelContextProtocol/IMcpEndpoint.cs | 2 +- .../Protocol/Types/ClientCapabilities.cs | 2 +- .../Protocol/Types/CompletionsCapability.cs | 2 +- .../Protocol/Types/EmptyResult.cs | 2 +- .../Protocol/Types/LoggingCapability.cs | 2 +- .../Protocol/Types/PromptsCapability.cs | 4 +- .../Protocol/Types/ResourcesCapability.cs | 10 +-- .../Protocol/Types/RootsCapability.cs | 2 +- .../Protocol/Types/SamplingCapability.cs | 2 +- .../Protocol/Types/ServerCapabilities.cs | 2 +- .../Protocol/Types/ToolsCapability.cs | 4 +- .../Server/AIFunctionMcpServerPrompt.cs | 2 +- .../Server/AIFunctionMcpServerTool.cs | 2 +- .../Server/DelegatingMcpServerPrompt.cs | 2 +- .../Server/DelegatingMcpServerTool.cs | 2 +- src/ModelContextProtocol/Server/McpServer.cs | 24 +++---- .../Server/McpServerHandlers.cs | 22 +++---- .../Server/McpServerPrompt.cs | 2 +- .../Server/McpServerTool.cs | 2 +- .../Shared/McpEndpoint.cs | 2 +- src/ModelContextProtocol/Shared/McpSession.cs | 2 +- .../Shared/NotificationHandlers.cs | 12 ++-- .../Shared/RequestHandlers.cs | 2 +- .../SseIntegrationTests.cs | 2 +- .../Program.cs | 65 ++++++++++--------- .../Program.cs | 44 +++++++------ .../Client/McpClientExtensionsTests.cs | 2 +- .../Client/McpClientFactoryTests.cs | 9 +-- .../ClientIntegrationTests.cs | 12 ++-- .../McpServerBuilderExtensionsHandlerTests.cs | 20 +++--- .../McpServerBuilderExtensionsPromptsTests.cs | 4 +- .../McpServerBuilderExtensionsToolsTests.cs | 6 +- .../DockerEverythingServerTests.cs | 6 +- .../Protocol/CancellationTests.cs | 2 +- .../Protocol/NotificationHandlerTests.cs | 4 +- .../Server/McpServerDelegatesTests.cs | 20 +++--- .../Server/McpServerLoggingLevelTests.cs | 10 +-- .../Server/McpServerResourceTests.cs | 30 ++++----- .../Server/McpServerTests.cs | 54 +++++++-------- 43 files changed, 242 insertions(+), 238 deletions(-) diff --git a/.editorconfig b/.editorconfig index c26008819..3ce6343ba 100644 --- a/.editorconfig +++ b/.editorconfig @@ -4,6 +4,9 @@ root = true # C# files [*.cs] +# Compiler +dotnet_diagnostic.CS1998.severity = suggestion # CS1998: Missing awaits + # Code Analysis dotnet_diagnostic.CA1002.severity = none # CA1002: Do not expose generic lists dotnet_diagnostic.CA1031.severity = none # CA1031: Do not catch general exception types diff --git a/samples/EverythingServer/Program.cs b/samples/EverythingServer/Program.cs index 56b41de53..a0966fe75 100644 --- a/samples/EverythingServer/Program.cs +++ b/samples/EverythingServer/Program.cs @@ -9,6 +9,8 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + var builder = Host.CreateApplicationBuilder(args); builder.Logging.AddConsole(consoleLogOptions => { @@ -31,17 +33,17 @@ .WithTools() .WithPrompts() .WithPrompts() - .WithListResourceTemplatesHandler((ctx, ct) => + .WithListResourceTemplatesHandler(async (ctx, ct) => { - return Task.FromResult(new ListResourceTemplatesResult + return new ListResourceTemplatesResult { ResourceTemplates = [ new ResourceTemplate { Name = "Static Resource", Description = "A static resource with a numeric ID", UriTemplate = "test://static/resource/{id}" } ] - }); + }; }) - .WithReadResourceHandler((ctx, ct) => + .WithReadResourceHandler(async (ctx, ct) => { var uri = ctx.Params?.Uri; @@ -61,7 +63,7 @@ if (resource.MimeType == "text/plain") { - return Task.FromResult(new ReadResourceResult + return new ReadResourceResult { Contents = [new TextResourceContents { @@ -69,11 +71,11 @@ MimeType = resource.MimeType, Uri = resource.Uri, }] - }); + }; } else { - return Task.FromResult(new ReadResourceResult + return new ReadResourceResult { Contents = [new BlobResourceContents { @@ -81,7 +83,7 @@ MimeType = resource.MimeType, Uri = resource.Uri, }] - }); + }; } }) .WithSubscribeToResourcesHandler(async (ctx, ct) => @@ -106,16 +108,16 @@ await ctx.Server.RequestSamplingAsync([ return new EmptyResult(); }) - .WithUnsubscribeFromResourcesHandler((ctx, ct) => + .WithUnsubscribeFromResourcesHandler(async (ctx, ct) => { var uri = ctx.Params?.Uri; if (uri is not null) { subscriptions.Remove(uri); } - return Task.FromResult(new EmptyResult()); + return new EmptyResult(); }) - .WithCompleteHandler((ctx, ct) => + .WithCompleteHandler(async (ctx, ct) => { var exampleCompletions = new Dictionary> { @@ -128,7 +130,7 @@ await ctx.Server.RequestSamplingAsync([ { throw new NotSupportedException($"Params are required."); } - + var @ref = @params.Ref; var argument = @params.Argument; @@ -138,15 +140,15 @@ await ctx.Server.RequestSamplingAsync([ if (resourceId is null) { - return Task.FromResult(new CompleteResult()); + return new CompleteResult(); } var values = exampleCompletions["resourceId"].Where(id => id.StartsWith(argument.Value)); - return Task.FromResult(new CompleteResult + return new CompleteResult { - Completion = new Completion { Values = [..values], HasMore = false, Total = values.Count() } - }); + Completion = new Completion { Values = [.. values], HasMore = false, Total = values.Count() } + }; } if (@ref.Type == "ref/prompt") @@ -157,10 +159,10 @@ await ctx.Server.RequestSamplingAsync([ } var values = value.Where(value => value.StartsWith(argument.Value)); - return Task.FromResult(new CompleteResult + return new CompleteResult { - Completion = new Completion { Values = [..values], HasMore = false, Total = values.Count() } - }); + Completion = new Completion { Values = [.. values], HasMore = false, Total = values.Count() } + }; } throw new NotSupportedException($"Unknown reference type: {@ref.Type}"); @@ -175,15 +177,14 @@ await ctx.Server.RequestSamplingAsync([ _minimumLoggingLevel = ctx.Params.Level; await ctx.Server.SendNotificationAsync("notifications/message", new - { - Level = "debug", - Logger = "test-server", - Data = $"Logging level set to {_minimumLoggingLevel}", - }, cancellationToken: ct); + { + Level = "debug", + Logger = "test-server", + Data = $"Logging level set to {_minimumLoggingLevel}", + }, cancellationToken: ct); return new EmptyResult(); - }) - ; + }); builder.Services.AddSingleton(subscriptions); builder.Services.AddHostedService(); diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index aaaa4263c..f495e89cb 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -948,7 +948,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat /// /// /// is . - public static Func, CancellationToken, Task> CreateSamplingHandler( + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( this IChatClient chatClient) { Throw.IfNull(chatClient); diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs index 521c4a191..05872aeb5 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -275,7 +275,7 @@ where t.GetCustomAttribute() is not null /// resource system where templates define the URI patterns and the read handler provides the actual content. /// /// - public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -308,7 +308,7 @@ public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServer /// executes them when invoked by clients. /// /// - public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -328,7 +328,7 @@ public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder buil /// This method is typically paired with to provide a complete tools implementation, /// where advertises available tools and this handler executes them. /// - public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -361,7 +361,7 @@ public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder build /// produces them when invoked by clients. /// /// - public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -376,7 +376,7 @@ public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder bu /// The handler function that processes prompt requests. /// The builder provided in . /// is . - public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -397,7 +397,7 @@ public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder buil /// where this handler advertises available resources and the read handler provides their content when requested. /// /// - public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -416,7 +416,7 @@ public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder /// This handler is typically paired with to provide a complete resources implementation, /// where the list handler advertises available resources and the read handler provides their content when requested. /// - public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -435,7 +435,7 @@ public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder b /// The completion handler is invoked when clients request suggestions for argument values. /// This enables auto-complete functionality for both prompt arguments and resource references. /// - public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -465,7 +465,7 @@ public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder build /// resources and to send appropriate notifications through the connection when resources change. /// /// - public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -495,7 +495,7 @@ public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerB /// to the specified resource. /// /// - public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); @@ -522,7 +522,7 @@ public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpSer /// most recently set level. /// /// - public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) { Throw.IfNull(builder); diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs index fede1b69d..a3b941cce 100644 --- a/src/ModelContextProtocol/IMcpEndpoint.cs +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -69,5 +69,5 @@ public interface IMcpEndpoint : IAsyncDisposable /// The notification method. /// The handler to be invoked. /// An that will remove the registered handler when disposed. - IAsyncDisposable RegisterNotificationHandler(string method, Func handler); + IAsyncDisposable RegisterNotificationHandler(string method, Func handler); } diff --git a/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs index a033193ac..3579232d8 100644 --- a/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs @@ -77,5 +77,5 @@ public class ClientCapabilities /// /// [JsonIgnore] - public IEnumerable>>? NotificationHandlers { get; set; } + public IEnumerable>>? NotificationHandlers { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/CompletionsCapability.cs b/src/ModelContextProtocol/Protocol/Types/CompletionsCapability.cs index 5c54307f7..70dc460fe 100644 --- a/src/ModelContextProtocol/Protocol/Types/CompletionsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/CompletionsCapability.cs @@ -33,5 +33,5 @@ public class CompletionsCapability /// and should return appropriate completion suggestions. /// [JsonIgnore] - public Func, CancellationToken, Task>? CompleteHandler { get; set; } + public Func, CancellationToken, ValueTask>? CompleteHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs b/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs index 98e37b5ab..402472b3c 100644 --- a/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs +++ b/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs @@ -9,5 +9,5 @@ namespace ModelContextProtocol.Protocol.Types; public class EmptyResult { [JsonIgnore] - internal static Task CompletedTask { get; } = Task.FromResult(new EmptyResult()); + internal static EmptyResult Instance { get; } = new(); } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs b/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs index f52baade3..18e60c424 100644 --- a/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs @@ -19,5 +19,5 @@ public class LoggingCapability /// Gets or sets the handler for set logging level requests from clients. /// [JsonIgnore] - public Func, CancellationToken, Task>? SetLoggingLevelHandler { get; set; } + public Func, CancellationToken, ValueTask>? SetLoggingLevelHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs b/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs index 0fcc216bc..53aa8043f 100644 --- a/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs @@ -41,7 +41,7 @@ public class PromptsCapability /// along with any prompts defined in . /// [JsonIgnore] - public Func, CancellationToken, Task>? ListPromptsHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListPromptsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -58,7 +58,7 @@ public class PromptsCapability /// /// [JsonIgnore] - public Func, CancellationToken, Task>? GetPromptHandler { get; set; } + public Func, CancellationToken, ValueTask>? GetPromptHandler { get; set; } /// /// Gets or sets a collection of prompts that will be served by the server. diff --git a/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs b/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs index 3219f02b1..3bb76378e 100644 --- a/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs @@ -40,7 +40,7 @@ public class ResourcesCapability /// allowing clients to discover available resource types and their access patterns. /// [JsonIgnore] - public Func, CancellationToken, Task>? ListResourceTemplatesHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListResourceTemplatesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -50,7 +50,7 @@ public class ResourcesCapability /// The implementation should return a with the matching resources. /// [JsonIgnore] - public Func, CancellationToken, Task>? ListResourcesHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -62,7 +62,7 @@ public class ResourcesCapability /// its contents in a ReadResourceResult object. /// [JsonIgnore] - public Func, CancellationToken, Task>? ReadResourceHandler { get; set; } + public Func, CancellationToken, ValueTask>? ReadResourceHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -75,7 +75,7 @@ public class ResourcesCapability /// requiring polling. /// [JsonIgnore] - public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; set; } + public Func, CancellationToken, ValueTask>? SubscribeToResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -86,5 +86,5 @@ public class ResourcesCapability /// about the specified resource. /// [JsonIgnore] - public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; set; } + public Func, CancellationToken, ValueTask>? UnsubscribeFromResourcesHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs b/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs index ed51115e0..d087cee7b 100644 --- a/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs @@ -41,5 +41,5 @@ public class RootsCapability /// The handler receives request parameters and should return a containing the collection of available roots. /// [JsonIgnore] - public Func>? RootsHandler { get; set; } + public Func>? RootsHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs b/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs index 55db6f247..96d9a6443 100644 --- a/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs @@ -40,5 +40,5 @@ public class SamplingCapability /// /// [JsonIgnore] - public Func, CancellationToken, Task>? SamplingHandler { get; set; } + public Func, CancellationToken, ValueTask>? SamplingHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs index 9a6c7dc9e..6406ea4dc 100644 --- a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs @@ -83,5 +83,5 @@ public class ServerCapabilities /// /// [JsonIgnore] - public IEnumerable>>? NotificationHandlers { get; set; } + public IEnumerable>>? NotificationHandlers { get; set; } } diff --git a/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs b/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs index 1e20e741b..87554398f 100644 --- a/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs @@ -34,7 +34,7 @@ public class ToolsCapability /// and the tools from the collection will be combined to form the complete list of available tools. /// [JsonIgnore] - public Func, CancellationToken, Task>? ListToolsHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListToolsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -46,7 +46,7 @@ public class ToolsCapability /// being called and its arguments, and should return a with the execution results. /// [JsonIgnore] - public Func, CancellationToken, Task>? CallToolHandler { get; set; } + public Func, CancellationToken, ValueTask>? CallToolHandler { get; set; } /// /// Gets or sets a collection of tools served by the server. diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs index fbdf49520..85bf38cad 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs @@ -206,7 +206,7 @@ private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt) /// objects are converted to prompt messages /// /// - public override async Task GetAsync( + public override async ValueTask GetAsync( RequestContext request, CancellationToken cancellationToken = default) { Throw.IfNull(request); diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 6fbf68376..76f0d6e5a 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -237,7 +237,7 @@ private AIFunctionMcpServerTool(AIFunction function, Tool tool) public override Tool ProtocolTool { get; } /// - public override async Task InvokeAsync( + public override async ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default) { Throw.IfNull(request); diff --git a/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs b/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs index 7cf052eb6..a209bea26 100644 --- a/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs @@ -24,7 +24,7 @@ protected DelegatingMcpServerPrompt(McpServerPrompt innerPrompt) public override Prompt ProtocolPrompt => _innerPrompt.ProtocolPrompt; /// - public override Task GetAsync( + public override ValueTask GetAsync( RequestContext request, CancellationToken cancellationToken = default) => _innerPrompt.GetAsync(request, cancellationToken); diff --git a/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs b/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs index d4555d714..2a4878a46 100644 --- a/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs @@ -24,7 +24,7 @@ protected DelegatingMcpServerTool(McpServerTool innerTool) public override Tool ProtocolTool => _innerTool.ProtocolTool; /// - public override Task InvokeAsync( + public override ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default) => _innerTool.InvokeAsync(request, cancellationToken); diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 5696efd8a..9b0559ee2 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -8,6 +8,8 @@ using ModelContextProtocol.Utils.Json; using System.Runtime.CompilerServices; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Server; /// @@ -157,7 +159,7 @@ public override async ValueTask DisposeUnsynchronizedAsync() private void SetPingHandler() { RequestHandlers.Set(RequestMethods.Ping, - (request, _) => Task.FromResult(new PingResult()), + async (request, _) => new PingResult(), McpJsonUtilities.JsonContext.Default.JsonNode, McpJsonUtilities.JsonContext.Default.PingResult); } @@ -165,7 +167,7 @@ private void SetPingHandler() private void SetInitializeHandler(McpServerOptions options) { RequestHandlers.Set(RequestMethods.Initialize, - (request, _) => + async (request, _) => { ClientCapabilities = request?.Capabilities ?? new(); ClientInfo = request?.ClientInfo; @@ -174,13 +176,13 @@ private void SetInitializeHandler(McpServerOptions options) _endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})"; GetSessionOrThrow().EndpointName = _endpointName; - return Task.FromResult(new InitializeResult + return new InitializeResult { ProtocolVersion = options.ProtocolVersion, Instructions = options.ServerInstructions, ServerInfo = options.ServerInfo ?? DefaultImplementation, Capabilities = ServerCapabilities ?? new(), - }); + }; }, McpJsonUtilities.JsonContext.Default.InitializeRequestParams, McpJsonUtilities.JsonContext.Default.InitializeResult); @@ -220,7 +222,7 @@ private void SetResourcesHandler(McpServerOptions options) throw new McpException("Resources capability was enabled, but ListResources and/or ReadResource handlers were not specified."); } - listResourcesHandler ??= (static (_, _) => Task.FromResult(new ListResourcesResult())); + listResourcesHandler ??= static async (_, _) => new ListResourcesResult(); RequestHandlers.Set( RequestMethods.ResourcesList, @@ -234,7 +236,7 @@ private void SetResourcesHandler(McpServerOptions options) McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult); - listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); + listResourceTemplatesHandler ??= static async (_, _) => new ListResourceTemplatesResult(); RequestHandlers.Set( RequestMethods.ResourcesTemplatesList, (request, cancellationToken) => InvokeHandlerAsync(listResourceTemplatesHandler, request, cancellationToken), @@ -480,14 +482,14 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) } // Otherwise, consider it handled. - return EmptyResult.CompletedTask; + return new ValueTask(EmptyResult.Instance); }, McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); } - private Task InvokeHandlerAsync( - Func, CancellationToken, Task> handler, + private ValueTask InvokeHandlerAsync( + Func, CancellationToken, ValueTask> handler, TParams? args, CancellationToken cancellationToken) { @@ -495,8 +497,8 @@ private Task InvokeHandlerAsync( InvokeScopedAsync(handler, args, cancellationToken) : handler(new(this) { Params = args }, cancellationToken); - async Task InvokeScopedAsync( - Func, CancellationToken, Task> handler, + async ValueTask InvokeScopedAsync( + Func, CancellationToken, ValueTask> handler, TParams? args, CancellationToken cancellationToken) { diff --git a/src/ModelContextProtocol/Server/McpServerHandlers.cs b/src/ModelContextProtocol/Server/McpServerHandlers.cs index b3fdf5af4..b749ab0fe 100644 --- a/src/ModelContextProtocol/Server/McpServerHandlers.cs +++ b/src/ModelContextProtocol/Server/McpServerHandlers.cs @@ -41,7 +41,7 @@ public sealed class McpServerHandlers /// Tools from both sources will be combined when returning results to clients. /// /// - public Func, CancellationToken, Task>? ListToolsHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListToolsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -50,7 +50,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client makes a call to a tool that isn't found in the collection. /// The handler should implement logic to execute the requested tool and return appropriate results. /// - public Func, CancellationToken, Task>? CallToolHandler { get; set; } + public Func, CancellationToken, ValueTask>? CallToolHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -66,7 +66,7 @@ public sealed class McpServerHandlers /// Prompts from both sources will be combined when returning results to clients. /// /// - public Func, CancellationToken, Task>? ListPromptsHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListPromptsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -75,7 +75,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client requests details for a specific prompt that isn't found in the collection. /// The handler should implement logic to fetch or generate the requested prompt and return appropriate results. /// - public Func, CancellationToken, Task>? GetPromptHandler { get; set; } + public Func, CancellationToken, ValueTask>? GetPromptHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -85,7 +85,7 @@ public sealed class McpServerHandlers /// It supports pagination through the cursor mechanism, where the client can make /// repeated calls with the cursor returned by the previous call to retrieve more resource templates. /// - public Func, CancellationToken, Task>? ListResourceTemplatesHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListResourceTemplatesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -95,7 +95,7 @@ public sealed class McpServerHandlers /// It supports pagination through the cursor mechanism, where the client can make /// repeated calls with the cursor returned by the previous call to retrieve more resources. /// - public Func, CancellationToken, Task>? ListResourcesHandler { get; set; } + public Func, CancellationToken, ValueTask>? ListResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -104,7 +104,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client requests the content of a specific resource identified by its URI. /// The handler should implement logic to locate and retrieve the requested resource. /// - public Func, CancellationToken, Task>? ReadResourceHandler { get; set; } + public Func, CancellationToken, ValueTask>? ReadResourceHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -114,7 +114,7 @@ public sealed class McpServerHandlers /// The handler processes auto-completion requests, returning a list of suggestions based on the /// reference type and current argument value. /// - public Func, CancellationToken, Task>? CompleteHandler { get; set; } + public Func, CancellationToken, ValueTask>? CompleteHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -130,7 +130,7 @@ public sealed class McpServerHandlers /// whenever a relevant resource is created, updated, or deleted. /// /// - public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; set; } + public Func, CancellationToken, ValueTask>? SubscribeToResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -146,7 +146,7 @@ public sealed class McpServerHandlers /// to the client for the specified resources. /// /// - public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; set; } + public Func, CancellationToken, ValueTask>? UnsubscribeFromResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -161,7 +161,7 @@ public sealed class McpServerHandlers /// at or above the specified level to the client as notifications/message notifications. /// /// - public Func, CancellationToken, Task>? SetLoggingLevelHandler { get; set; } + public Func, CancellationToken, ValueTask>? SetLoggingLevelHandler { get; set; } /// /// Overwrite any handlers in McpServerOptions with non-null handlers from this instance. diff --git a/src/ModelContextProtocol/Server/McpServerPrompt.cs b/src/ModelContextProtocol/Server/McpServerPrompt.cs index 225ff24b5..f1c7187fb 100644 --- a/src/ModelContextProtocol/Server/McpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/McpServerPrompt.cs @@ -147,7 +147,7 @@ protected McpServerPrompt() /// /// is . /// The prompt implementation returns or an unsupported result type. - public abstract Task GetAsync( + public abstract ValueTask GetAsync( RequestContext request, CancellationToken cancellationToken = default); diff --git a/src/ModelContextProtocol/Server/McpServerTool.cs b/src/ModelContextProtocol/Server/McpServerTool.cs index 9a213eb75..f169c64c7 100644 --- a/src/ModelContextProtocol/Server/McpServerTool.cs +++ b/src/ModelContextProtocol/Server/McpServerTool.cs @@ -150,7 +150,7 @@ protected McpServerTool() /// The to monitor for cancellation requests. The default is . /// The call response from invoking the tool. /// is . - public abstract Task InvokeAsync( + public abstract ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default); diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 6372c2ef3..bbe6fbd4a 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -49,7 +49,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => GetSessionOrThrow().RegisterNotificationHandler(method, handler); /// diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 043680b42..062d8cf44 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -314,7 +314,7 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can }, Tuple.Create(this, requestId)); } - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) { Throw.IfNullOrWhiteSpace(method); Throw.IfNull(handler); diff --git a/src/ModelContextProtocol/Shared/NotificationHandlers.cs b/src/ModelContextProtocol/Shared/NotificationHandlers.cs index 70941a9cc..8e315ce09 100644 --- a/src/ModelContextProtocol/Shared/NotificationHandlers.cs +++ b/src/ModelContextProtocol/Shared/NotificationHandlers.cs @@ -38,7 +38,7 @@ internal sealed class NotificationHandlers /// with the corresponding method name is received. /// /// - public void RegisterRange(IEnumerable>> handlers) + public void RegisterRange(IEnumerable>> handlers) { foreach (var entry in handlers) { @@ -64,7 +64,7 @@ public void RegisterRange(IEnumerable public IAsyncDisposable Register( - string method, Func handler, bool temporary = true) + string method, Func handler, bool temporary = true) { // Create the new registration instance. Registration reg = new(this, method, handler, temporary); @@ -143,7 +143,7 @@ public async Task InvokeHandlers(string method, JsonRpcNotification notification /// Provides storage for a handler registration. private sealed class Registration( - NotificationHandlers handlers, string method, Func handler, bool unregisterable) : IAsyncDisposable + NotificationHandlers handlers, string method, Func handler, bool unregisterable) : IAsyncDisposable { /// Used to prevent deadlocks during disposal. /// @@ -165,7 +165,7 @@ private sealed class Registration( private readonly string _method = method; /// The handler this registration represents. - private readonly Func _handler = handler; + private readonly Func _handler = handler; /// true if this instance is temporary; false if it's permanent private readonly bool _temporary = unregisterable; @@ -272,7 +272,7 @@ public async ValueTask DisposeAsync() } /// Invoke the handler associated with the registration. - public Task InvokeAsync(JsonRpcNotification notification, CancellationToken cancellationToken) + public ValueTask InvokeAsync(JsonRpcNotification notification, CancellationToken cancellationToken) { // For permanent registrations, skip all the tracking overhead and just invoke the handler. if (!_temporary) @@ -285,7 +285,7 @@ public Task InvokeAsync(JsonRpcNotification notification, CancellationToken canc } /// Invoke the handler associated with the temporary registration. - private async Task InvokeTemporaryAsync(JsonRpcNotification notification, CancellationToken cancellationToken) + private async ValueTask InvokeTemporaryAsync(JsonRpcNotification notification, CancellationToken cancellationToken) { // Check whether we need to handle this registration. If DisposeAsync has been called, // then even if there are in-flight invocations for it, we avoid adding more. diff --git a/src/ModelContextProtocol/Shared/RequestHandlers.cs b/src/ModelContextProtocol/Shared/RequestHandlers.cs index b41317118..184fd9077 100644 --- a/src/ModelContextProtocol/Shared/RequestHandlers.cs +++ b/src/ModelContextProtocol/Shared/RequestHandlers.cs @@ -30,7 +30,7 @@ internal sealed class RequestHandlers : Dictionary public void Set( string method, - Func> handler, + Func> handler, JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 83c601aa4..50d54dbb3 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -87,7 +87,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() { var msg = args.Params?["message"]?.GetValue(); receivedNotification.SetResult(msg); - return Task.CompletedTask; + return default; }); // Send a test message through POST endpoint diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index ee97f7bb7..4fe963d26 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -8,6 +8,8 @@ using System.Text; using System.Text.Json; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.TestServer; internal static class Program @@ -110,9 +112,9 @@ private static ToolsCapability ConfigureTools() { return new() { - ListToolsHandler = (request, cancellationToken) => + ListToolsHandler = async (request, cancellationToken) => { - return Task.FromResult(new ListToolsResult() + return new ListToolsResult() { Tools = [ @@ -155,7 +157,7 @@ private static ToolsCapability ConfigureTools() """), } ] - }); + }; }, CallToolHandler = async (request, cancellationToken) => @@ -199,9 +201,9 @@ private static PromptsCapability ConfigurePrompts() { return new() { - ListPromptsHandler = (request, cancellationToken) => + ListPromptsHandler = async (request, cancellationToken) => { - return Task.FromResult(new ListPromptsResult() + return new ListPromptsResult() { Prompts = [ new Prompt() @@ -230,10 +232,10 @@ private static PromptsCapability ConfigurePrompts() ] } ] - }); + }; }, - GetPromptHandler = (request, cancellationToken) => + GetPromptHandler = async (request, cancellationToken) => { List messages = []; if (request.Params?.Name == "simple_prompt") @@ -286,10 +288,10 @@ private static PromptsCapability ConfigurePrompts() throw new McpException($"Unknown prompt: {request.Params?.Name}"); } - return Task.FromResult(new GetPromptResult() + return new GetPromptResult() { Messages = messages - }); + }; } }; } @@ -300,7 +302,7 @@ private static LoggingCapability ConfigureLogging() { return new() { - SetLoggingLevelHandler = (request, cancellationToken) => + SetLoggingLevelHandler = async (request, cancellationToken) => { if (request.Params?.Level is null) { @@ -309,7 +311,7 @@ private static LoggingCapability ConfigureLogging() _minimumLoggingLevel = request.Params.Level; - return Task.FromResult(new EmptyResult()); + return new EmptyResult(); } }; } @@ -360,9 +362,9 @@ private static ResourcesCapability ConfigureResources() return new() { - ListResourceTemplatesHandler = (request, cancellationToken) => + ListResourceTemplatesHandler = async (request, cancellationToken) => { - return Task.FromResult(new ListResourceTemplatesResult() + return new ListResourceTemplatesResult() { ResourceTemplates = [ new ResourceTemplate() @@ -371,10 +373,10 @@ private static ResourcesCapability ConfigureResources() Name = "Dynamic Resource", } ] - }); + }; }, - ListResourcesHandler = (request, cancellationToken) => + ListResourcesHandler = async (request, cancellationToken) => { int startIndex = 0; if (request.Params?.Cursor is not null) @@ -397,14 +399,14 @@ private static ResourcesCapability ConfigureResources() { nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } - return Task.FromResult(new ListResourcesResult() + return new ListResourcesResult() { NextCursor = nextCursor, Resources = resources.GetRange(startIndex, endIndex - startIndex) - }); + }; }, - ReadResourceHandler = (request, cancellationToken) => + ReadResourceHandler = async (request, cancellationToken) => { if (request.Params?.Uri is null) { @@ -418,7 +420,8 @@ private static ResourcesCapability ConfigureResources() { throw new McpException("Invalid resource URI"); } - return Task.FromResult(new ReadResourceResult() + + return new ReadResourceResult() { Contents = [ new TextResourceContents() @@ -428,19 +431,19 @@ private static ResourcesCapability ConfigureResources() Text = $"Dynamic resource {id}: This is a plaintext resource" } ] - }); + }; } ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? throw new McpException("Resource not found"); - return Task.FromResult(new ReadResourceResult() + return new ReadResourceResult() { Contents = [contents] - }); + }; }, - SubscribeToResourcesHandler = (request, cancellationToken) => + SubscribeToResourcesHandler = async (request, cancellationToken) => { if (request?.Params?.Uri is null) { @@ -454,10 +457,10 @@ private static ResourcesCapability ConfigureResources() _subscribedResources.TryAdd(request.Params.Uri, true); - return Task.FromResult(new EmptyResult()); + return new EmptyResult(); }, - UnsubscribeFromResourcesHandler = (request, cancellationToken) => + UnsubscribeFromResourcesHandler = async (request, cancellationToken) => { if (request?.Params?.Uri is null) { @@ -471,7 +474,7 @@ private static ResourcesCapability ConfigureResources() _subscribedResources.TryRemove(request.Params.Uri, out _); - return Task.FromResult(new EmptyResult()); + return new EmptyResult(); }, Subscribe = true @@ -487,17 +490,17 @@ private static CompletionsCapability ConfigureCompletions() {"temperature", ["0", "0.5", "0.7", "1.0"]}, }; - Func, CancellationToken, Task> handler = (request, cancellationToken) => + Func, CancellationToken, ValueTask> handler = async (request, cancellationToken) => { if (request.Params?.Ref?.Type == "ref/resource") { var resourceId = request.Params?.Ref?.Uri?.Split('/').LastOrDefault(); if (string.IsNullOrEmpty(resourceId)) - return Task.FromResult(new CompleteResult() { Completion = new() { Values = [] } }); + return new CompleteResult() { Completion = new() { Values = [] } }; // Filter resource IDs that start with the input value var values = sampleResourceIds.Where(id => id.StartsWith(request.Params!.Argument.Value)).ToArray(); - return Task.FromResult(new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }); + return new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }; } @@ -505,10 +508,10 @@ private static CompletionsCapability ConfigureCompletions() { // Handle completion for prompt arguments if (!exampleCompletions.TryGetValue(request.Params.Argument.Name, out var completions)) - return Task.FromResult(new CompleteResult() { Completion = new() { Values = [] } }); + return new CompleteResult() { Completion = new() { Values = [] } }; var values = completions.Where(value => value.StartsWith(request.Params.Argument.Value)).ToArray(); - return Task.FromResult(new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }); + return new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }; } throw new McpException($"Unknown reference type: {request.Params?.Ref.Type}"); diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index f364c4a12..37ecc1a1c 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -6,6 +6,8 @@ using System.Text; using System.Text.Json; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.TestSseServer; public class Program @@ -104,9 +106,9 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { Tools = new() { - ListToolsHandler = (request, cancellationToken) => + ListToolsHandler = async (request, cancellationToken) => { - return Task.FromResult(new ListToolsResult() + return new ListToolsResult() { Tools = [ @@ -149,7 +151,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st """, McpJsonUtilities.DefaultOptions), } ] - }); + }; }, CallToolHandler = async (request, cancellationToken) => { @@ -192,10 +194,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }, Resources = new() { - ListResourceTemplatesHandler = (request, cancellationToken) => + ListResourceTemplatesHandler = async (request, cancellationToken) => { - return Task.FromResult(new ListResourceTemplatesResult() + return new ListResourceTemplatesResult() { ResourceTemplates = [ new ResourceTemplate() @@ -204,10 +206,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st Name = "Dynamic Resource", } ] - }); + }; }, - ListResourcesHandler = (request, cancellationToken) => + ListResourcesHandler = async (request, cancellationToken) => { int startIndex = 0; var requestParams = request.Params ?? new(); @@ -231,13 +233,14 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } - return Task.FromResult(new ListResourcesResult() + + return new ListResourcesResult() { NextCursor = nextCursor, Resources = resources.GetRange(startIndex, endIndex - startIndex) - }); + }; }, - ReadResourceHandler =(request, cancellationToken) => + ReadResourceHandler = async (request, cancellationToken) => { if (request.Params?.Uri is null) { @@ -251,7 +254,8 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { throw new McpException("Invalid resource URI"); } - return Task.FromResult(new ReadResourceResult() + + return new ReadResourceResult() { Contents = [ new TextResourceContents() @@ -261,23 +265,23 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st Text = $"Dynamic resource {id}: This is a plaintext resource" } ] - }); + }; } ResourceContents? contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? throw new McpException("Resource not found"); - return Task.FromResult(new ReadResourceResult() + return new ReadResourceResult() { Contents = [contents] - }); + }; } }, Prompts = new() { - ListPromptsHandler = (request, cancellationToken) => + ListPromptsHandler = async (request, cancellationToken) => { - return Task.FromResult(new ListPromptsResult() + return new ListPromptsResult() { Prompts = [ new Prompt() @@ -306,9 +310,9 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } } ] - }); + }; }, - GetPromptHandler = (request, cancellationToken) => + GetPromptHandler = async (request, cancellationToken) => { if (request.Params is null) { @@ -365,10 +369,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st throw new McpException($"Unknown prompt: {request.Params.Name}"); } - return Task.FromResult(new GetPromptResult() + return new GetPromptResult() { Messages = messages - }); + }; } }, }; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 3e5d56805..46c1879cd 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -381,7 +381,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() (notification, cancellationToken) => { Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions))); - return Task.CompletedTask; + return default; })) { logger.LogTrace("Trace {Message}", "message"); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index ed6d35632..daee177c1 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -41,18 +41,19 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) { Sampling = new SamplingCapability { - SamplingHandler = (c, p, t) => Task.FromResult( - new CreateMessageResult { + SamplingHandler = async (c, p, t) => + new CreateMessageResult + { Content = new Content { Text = "result" }, Model = "test-model", Role = Role.User, StopReason = "endTurn" - }), + }, }, Roots = new RootsCapability { ListChanged = true, - RootsHandler = (t, r) => Task.FromResult(new ListRootsResult { Roots = [] }), + RootsHandler = async (t, r) => new ListRootsResult { Roots = [] }, } } }; diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 746ab78dc..d6944df97 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -265,7 +265,7 @@ public async Task SubscribeResource_Stdio() { var notificationParams = JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions); tcs.TrySetResult(true); - return Task.CompletedTask; + return default; }) ] } @@ -295,7 +295,7 @@ public async Task UnsubscribeResource_Stdio() { var notificationParams = JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions); receivedNotification.TrySetResult(true); - return Task.CompletedTask; + return default; }) ] } @@ -370,10 +370,10 @@ public async Task Sampling_Stdio(string clientId) { Sampling = new() { - SamplingHandler = (_, _, _) => + SamplingHandler = async (_, _, _) => { samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult + return new CreateMessageResult { Model = "test-model", Role = Role.Assistant, @@ -382,7 +382,7 @@ public async Task Sampling_Stdio(string clientId) Type = "text", Text = "Test response" } - }); + }; }, }, }, @@ -569,7 +569,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) { receivedNotification.TrySetResult(true); } - return Task.CompletedTask; + return default; }) ] } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs index 896765d97..f304f8309 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs @@ -21,7 +21,7 @@ public McpServerBuilderExtensionsHandlerTests() [Fact] public void WithListToolsHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new ListToolsResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new ListToolsResult(); _builder.Object.WithListToolsHandler(handler); @@ -34,7 +34,7 @@ public void WithListToolsHandler_Sets_Handler() [Fact] public void WithCallToolHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new CallToolResponse()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new CallToolResponse(); _builder.Object.WithCallToolHandler(handler); @@ -47,7 +47,7 @@ public void WithCallToolHandler_Sets_Handler() [Fact] public void WithListPromptsHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new ListPromptsResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new ListPromptsResult(); _builder.Object.WithListPromptsHandler(handler); @@ -60,7 +60,7 @@ public void WithListPromptsHandler_Sets_Handler() [Fact] public void WithGetPromptHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new GetPromptResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new GetPromptResult(); _builder.Object.WithGetPromptHandler(handler); @@ -73,7 +73,7 @@ public void WithGetPromptHandler_Sets_Handler() [Fact] public void WithListResourceTemplatesHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new ListResourceTemplatesResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new ListResourceTemplatesResult(); _builder.Object.WithListResourceTemplatesHandler(handler); @@ -86,7 +86,7 @@ public void WithListResourceTemplatesHandler_Sets_Handler() [Fact] public void WithListResourcesHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new ListResourcesResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new ListResourcesResult(); _builder.Object.WithListResourcesHandler(handler); @@ -99,7 +99,7 @@ public void WithListResourcesHandler_Sets_Handler() [Fact] public void WithReadResourceHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new ReadResourceResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new ReadResourceResult(); _builder.Object.WithReadResourceHandler(handler); @@ -112,7 +112,7 @@ public void WithReadResourceHandler_Sets_Handler() [Fact] public void WithCompleteHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new CompleteResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new CompleteResult(); _builder.Object.WithCompleteHandler(handler); @@ -125,7 +125,7 @@ public void WithCompleteHandler_Sets_Handler() [Fact] public void WithSubscribeToResourcesHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new EmptyResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new EmptyResult(); _builder.Object.WithSubscribeToResourcesHandler(handler); @@ -138,7 +138,7 @@ public void WithSubscribeToResourcesHandler_Sets_Handler() [Fact] public void WithUnsubscribeFromResourcesHandler_Sets_Handler() { - Func, CancellationToken, Task> handler = (context, token) => Task.FromResult(new EmptyResult()); + Func, CancellationToken, ValueTask> handler = async (context, token) => new EmptyResult(); _builder.Object.WithUnsubscribeFromResourcesHandler(handler); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index b6a980f90..289330aea 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -9,8 +9,6 @@ using System.Text.Json.Serialization; using System.Threading.Channels; -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - namespace ModelContextProtocol.Tests.Configuration; public partial class McpServerBuilderExtensionsPromptsTests : ClientServerTestBase @@ -144,7 +142,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() await using (client.RegisterNotificationHandler("notifications/prompts/list_changed", (notification, cancellationToken) => { listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; + return default; })) { serverPrompts.Add(newPrompt); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 3e51a8f4a..c8a407e8e 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -16,8 +16,6 @@ using System.Text.RegularExpressions; using System.Threading.Channels; -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - namespace ModelContextProtocol.Tests.Configuration; public partial class McpServerBuilderExtensionsToolsTests : ClientServerTestBase @@ -209,7 +207,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await using (client.RegisterNotificationHandler(NotificationMethods.ToolListChangedNotification, (notification, cancellationToken) => { listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; + return default; })) { serverTools.Add(newTool); @@ -588,7 +586,7 @@ public async Task HandlesIProgressParameter() { ProgressNotification pn = JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions)!; notifications.Enqueue(pn); - return Task.CompletedTask; + return default; })) { var result = await client.SendRequestAsync( diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index 149c7d639..d0c59603d 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -77,10 +77,10 @@ public async Task Sampling_Sse_EverythingServer() { Sampling = new() { - SamplingHandler = (_, _, _) => + SamplingHandler = async (_, _, _) => { samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult + return new CreateMessageResult { Model = "test-model", Role = Role.Assistant, @@ -89,7 +89,7 @@ public async Task Sampling_Sse_EverythingServer() Type = "text", Text = "Test response" } - }); + }; }, }, }, diff --git a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs index c90e78da2..81a0f87df 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs @@ -39,7 +39,7 @@ public async Task PrecancelRequest_CancelsBeforeSending() await using (Server.RegisterNotificationHandler(NotificationMethods.CancelledNotification, (notification, cancellationToken) => { gotCancellation = true; - return Task.CompletedTask; + return default; })) { await Assert.ThrowsAsync(() => client.ListToolsAsync(cancellationToken: new CancellationToken(true))); diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 9ebd1e387..6aa7ccf02 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -25,7 +25,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() { Interlocked.Increment(ref counter); tcs.SetResult(true); - return Task.CompletedTask; + return default; })) { await Server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); @@ -58,7 +58,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() tcs.TrySetResult(true); } - return Task.CompletedTask; + return default; }); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs index e897ff4d6..31176bc8c 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs @@ -21,16 +21,16 @@ public void AllPropertiesAreSettable() Assert.Null(handlers.SubscribeToResourcesHandler); Assert.Null(handlers.UnsubscribeFromResourcesHandler); - handlers.ListToolsHandler = (p, c) => Task.FromResult(new ListToolsResult()); - handlers.CallToolHandler = (p, c) => Task.FromResult(new CallToolResponse()); - handlers.ListPromptsHandler = (p, c) => Task.FromResult(new ListPromptsResult()); - handlers.GetPromptHandler = (p, c) => Task.FromResult(new GetPromptResult()); - handlers.ListResourceTemplatesHandler = (p, c) => Task.FromResult(new ListResourceTemplatesResult()); - handlers.ListResourcesHandler = (p, c) => Task.FromResult(new ListResourcesResult()); - handlers.ReadResourceHandler = (p, c) => Task.FromResult(new ReadResourceResult()); - handlers.CompleteHandler = (p, c) => Task.FromResult(new CompleteResult()); - handlers.SubscribeToResourcesHandler = (s, c) => Task.FromResult(new EmptyResult()); - handlers.UnsubscribeFromResourcesHandler = (s, c) => Task.FromResult(new EmptyResult()); + handlers.ListToolsHandler = async (p, c) => new ListToolsResult(); + handlers.CallToolHandler = async (p, c) => new CallToolResponse(); + handlers.ListPromptsHandler = async (p, c) => new ListPromptsResult(); + handlers.GetPromptHandler = async (p, c) => new GetPromptResult(); + handlers.ListResourceTemplatesHandler = async (p, c) => new ListResourceTemplatesResult(); + handlers.ListResourcesHandler = async (p, c) => new ListResourcesResult(); + handlers.ReadResourceHandler = async (p, c) => new ReadResourceResult(); + handlers.CompleteHandler = async (p, c) => new CompleteResult(); + handlers.SubscribeToResourcesHandler = async (s, c) => new EmptyResult(); + handlers.UnsubscribeFromResourcesHandler = async (s, c) => new EmptyResult(); Assert.NotNull(handlers.ListToolsHandler); Assert.NotNull(handlers.CallToolHandler); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index a97fe6f3f..f8ffc9521 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -12,10 +12,7 @@ public void CanCreateServerWithLoggingLevelHandler() services.AddMcpServer() .WithStdioServerTransport() - .WithSetLoggingLevelHandler((ctx, ct) => - { - return Task.FromResult(new EmptyResult()); - }); + .WithSetLoggingLevelHandler(async (ctx, ct) => new EmptyResult()); var provider = services.BuildServiceProvider(); @@ -29,10 +26,7 @@ public void AddingLoggingLevelHandlerSetsLoggingCapability() services.AddMcpServer() .WithStdioServerTransport() - .WithSetLoggingLevelHandler((ctx, ct) => - { - return Task.FromResult(new EmptyResult()); - }); + .WithSetLoggingLevelHandler(async (ctx, ct) => new EmptyResult()); var provider = services.BuildServiceProvider(); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 1c7e66134..6182ddc07 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -13,19 +13,19 @@ public void CanCreateServerWithResourceTemplates() services.AddMcpServer() .WithStdioServerTransport() - .WithListResourceTemplatesHandler((ctx, ct) => + .WithListResourceTemplatesHandler(async (ctx, ct) => { - return Task.FromResult(new ListResourceTemplatesResult + return new ListResourceTemplatesResult { ResourceTemplates = [ new ResourceTemplate { Name = "Static Resource", Description = "A static resource with a numeric ID", UriTemplate = "test://static/resource/{id}" } ] - }); + }; }) - .WithReadResourceHandler((ctx, ct) => + .WithReadResourceHandler(async (ctx, ct) => { - return Task.FromResult(new ReadResourceResult + return new ReadResourceResult { Contents = [new TextResourceContents { @@ -33,7 +33,7 @@ public void CanCreateServerWithResourceTemplates() Text = "Static Resource", MimeType = "text/plain", }] - }); + }; }); var provider = services.BuildServiceProvider(); @@ -48,19 +48,19 @@ public void CanCreateServerWithResources() services.AddMcpServer() .WithStdioServerTransport() - .WithListResourcesHandler((ctx, ct) => + .WithListResourcesHandler(async (ctx, ct) => { - return Task.FromResult(new ListResourcesResult + return new ListResourcesResult { Resources = [ new Resource { Name = "Static Resource", Description = "A static resource with a numeric ID", Uri = "test://static/resource/foo.txt" } ] - }); + }; }) - .WithReadResourceHandler((ctx, ct) => + .WithReadResourceHandler(async (ctx, ct) => { - return Task.FromResult(new ReadResourceResult + return new ReadResourceResult { Contents = [new TextResourceContents { @@ -68,7 +68,7 @@ public void CanCreateServerWithResources() Text = "Static Resource", MimeType = "text/plain", }] - }); + }; }); var provider = services.BuildServiceProvider(); @@ -82,9 +82,9 @@ public void CreatingReadHandlerWithNoListHandlerFails() var services = new ServiceCollection(); services.AddMcpServer() .WithStdioServerTransport() - .WithReadResourceHandler((ctx, ct) => + .WithReadResourceHandler(async (ctx, ct) => { - return Task.FromResult(new ReadResourceResult + return new ReadResourceResult { Contents = [new TextResourceContents { @@ -92,7 +92,7 @@ public void CreatingReadHandlerWithNoListHandlerFails() Text = "Static Resource", MimeType = "text/plain", }] - }); + }; }); var sp = services.BuildServiceProvider(); Assert.Throws(() => sp.GetRequiredService()); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 5dd572ae0..34ac9c8a0 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -202,8 +202,8 @@ await Can_Handle_Requests( { Completions = new() { - CompleteHandler = (request, ct) => - Task.FromResult(new CompleteResult + CompleteHandler = async (request, ct) => + new CompleteResult { Completion = new() { @@ -211,7 +211,7 @@ await Can_Handle_Requests( Total = 2, HasMore = true } - }) + } } }, method: RequestMethods.CompletionComplete, @@ -234,19 +234,19 @@ await Can_Handle_Requests( { Resources = new() { - ListResourceTemplatesHandler = (request, ct) => + ListResourceTemplatesHandler = async (request, ct) => { - return Task.FromResult(new ListResourceTemplatesResult + return new ListResourceTemplatesResult { ResourceTemplates = [new() { UriTemplate = "test", Name = "Test Resource" }] - }); + }; }, - ListResourcesHandler = (request, ct) => + ListResourcesHandler = async (request, ct) => { - return Task.FromResult(new ListResourcesResult + return new ListResourcesResult { Resources = [new() { Uri = "test", Name = "Test Resource" }] - }); + }; }, ReadResourceHandler = (request, ct) => throw new NotImplementedException(), } @@ -270,12 +270,12 @@ await Can_Handle_Requests( { Resources = new() { - ListResourcesHandler = (request, ct) => + ListResourcesHandler = async (request, ct) => { - return Task.FromResult(new ListResourcesResult + return new ListResourcesResult { Resources = [new() { Uri = "test", Name = "Test Resource" }] - }); + }; }, ReadResourceHandler = (request, ct) => throw new NotImplementedException(), } @@ -305,12 +305,12 @@ await Can_Handle_Requests( { Resources = new() { - ReadResourceHandler = (request, ct) => + ReadResourceHandler = async (request, ct) => { - return Task.FromResult(new ReadResourceResult + return new ReadResourceResult { Contents = [new TextResourceContents { Text = "test" }] - }); + }; }, ListResourcesHandler = (request, ct) => throw new NotImplementedException(), } @@ -342,12 +342,12 @@ await Can_Handle_Requests( { Prompts = new() { - ListPromptsHandler = (request, ct) => + ListPromptsHandler = async (request, ct) => { - return Task.FromResult(new ListPromptsResult + return new ListPromptsResult { Prompts = [new() { Name = "test" }] - }); + }; }, GetPromptHandler = (request, ct) => throw new NotImplementedException(), }, @@ -377,7 +377,7 @@ await Can_Handle_Requests( { Prompts = new() { - GetPromptHandler = (request, ct) => Task.FromResult(new GetPromptResult { Description = "test" }), + GetPromptHandler = async (request, ct) => new GetPromptResult { Description = "test" }, ListPromptsHandler = (request, ct) => throw new NotImplementedException(), } }, @@ -405,12 +405,12 @@ await Can_Handle_Requests( { Tools = new() { - ListToolsHandler = (request, ct) => + ListToolsHandler = async (request, ct) => { - return Task.FromResult(new ListToolsResult + return new ListToolsResult { Tools = [new() { Name = "test" }] - }); + }; }, CallToolHandler = (request, ct) => throw new NotImplementedException(), } @@ -440,12 +440,12 @@ await Can_Handle_Requests( { Tools = new() { - CallToolHandler = (request, ct) => + CallToolHandler = async (request, ct) => { - return Task.FromResult(new CallToolResponse + return new CallToolResponse { Content = [new Content { Text = "test" }] - }); + }; }, ListToolsHandler = (request, ct) => throw new NotImplementedException(), } @@ -623,7 +623,7 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella throw new NotImplementedException(); public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => throw new NotImplementedException(); } @@ -639,7 +639,7 @@ public async Task NotifyProgress_Should_Be_Handled() NotificationHandlers = [new(NotificationMethods.ProgressNotification, (notification, cancellationToken) => { notificationReceived.TrySetResult(notification); - return Task.CompletedTask; + return default; })], }; From 812e5f7f1c46694e10731e4ce2954c19bdca92b0 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 11 Apr 2025 17:13:37 -0400 Subject: [PATCH 3/3] Address feedback --- .../SingleSessionMcpServerHostedService.cs | 2 +- .../Transport/StreamServerTransport.cs | 2 +- src/ModelContextProtocol/Server/McpServer.cs | 25 ++++++---- .../McpServerBuilderExtensionsToolsTests.cs | 45 ++++++++++++++---- .../Configuration/McpServerScopedTests.cs | 47 ++++++++++++++++--- .../ModelContextProtocol.Tests/JsonContext.cs | 20 -------- 6 files changed, 97 insertions(+), 44 deletions(-) delete mode 100644 tests/ModelContextProtocol.Tests/JsonContext.cs diff --git a/src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs b/src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs index 60316bcf5..70791273e 100644 --- a/src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs +++ b/src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs @@ -17,7 +17,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { try { - await session.RunAsync(stoppingToken); + await session.RunAsync(stoppingToken).ConfigureAwait(false); } finally { diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index f08f0db3e..10a786e86 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -82,7 +82,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false); await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); - await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);; + await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false); _logger.TransportSentMessage(_endpointName, id); } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 9b0559ee2..079bcb9b1 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -502,15 +502,24 @@ async ValueTask InvokeScopedAsync( TParams? args, CancellationToken cancellationToken) { - using var scope = Services?.GetService()?.CreateScope(); - - return await handler( - new RequestContext(this) + var scope = Services?.GetService()?.CreateAsyncScope(); + try + { + return await handler( + new RequestContext(this) + { + Services = scope?.ServiceProvider ?? Services, + Params = args + }, + cancellationToken).ConfigureAwait(false); + } + finally + { + if (scope is not null) { - Services = scope?.ServiceProvider ?? Services, - Params = args - }, - cancellationToken).ConfigureAwait(false); + await scope.Value.DisposeAsync().ConfigureAwait(false); + } + } } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index c8a407e8e..d833356a6 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -107,7 +107,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer throw new Exception($"Unknown tool '{request.Params?.Name}'"); } }) - .WithTools(serializerOptions: JsonContext.Default.Options); + .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options); services.AddSingleton(new ObjectWithId()); } @@ -429,8 +429,13 @@ public void Empty_Enumerables_Is_Allowed() [Fact] public void Register_Tools_From_Current_Assembly() { + if (!JsonSerializer.IsReflectionEnabledByDefault) + { + return; + } + ServiceCollection sc = new(); - sc.AddMcpServer().WithToolsFromAssembly(serializerOptions: JsonContext.Default.Options); + sc.AddMcpServer().WithToolsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "Echo"); @@ -446,7 +451,7 @@ public void WithTools_Parameters_Satisfiable_From_DI(bool parameterInServices) { sc.AddSingleton(new ComplexObject()); } - sc.AddMcpServer().WithTools([typeof(EchoTool)], JsonContext.Default.Options); + sc.AddMcpServer().WithTools([typeof(EchoTool)], BuilderToolsJsonContext.Default.Options); IServiceProvider services = sc.BuildServiceProvider(); McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "EchoComplex"); @@ -460,6 +465,7 @@ public void WithTools_Parameters_Satisfiable_From_DI(bool parameterInServices) } } + [Theory] [InlineData(ServiceLifetime.Singleton)] [InlineData(ServiceLifetime.Scoped)] @@ -467,6 +473,11 @@ public void WithTools_Parameters_Satisfiable_From_DI(bool parameterInServices) [InlineData(null)] public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime? lifetime) { + if (!JsonSerializer.IsReflectionEnabledByDefault) + { + return; + } + ServiceCollection sc = new(); switch (lifetime) { @@ -483,7 +494,7 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime break; } - sc.AddMcpServer().WithToolsFromAssembly(serializerOptions: JsonContext.Default.Options); + sc.AddMcpServer().WithToolsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "EchoComplex"); @@ -526,9 +537,9 @@ public void Register_Tools_From_Multiple_Sources() { ServiceCollection sc = new(); sc.AddMcpServer() - .WithTools(serializerOptions: JsonContext.Default.Options) - .WithTools(serializerOptions: JsonContext.Default.Options) - .WithTools([typeof(ToolTypeWithNoAttribute)], JsonContext.Default.Options); + .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options) + .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options) + .WithTools([typeof(ToolTypeWithNoAttribute)], BuilderToolsJsonContext.Default.Options); IServiceProvider services = sc.BuildServiceProvider(); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "double_echo"); @@ -759,4 +770,22 @@ public class ObjectWithId { public string Id { get; set; } = Guid.NewGuid().ToString("N"); } -} + + public class ComplexObject + { + public string? Name { get; set; } + public int Age { get; set; } + } + + [JsonSerializable(typeof(bool))] + [JsonSerializable(typeof(int))] + [JsonSerializable(typeof(long))] + [JsonSerializable(typeof(double))] + [JsonSerializable(typeof(string))] + [JsonSerializable(typeof(DateTime))] + [JsonSerializable(typeof(DateTimeOffset))] + [JsonSerializable(typeof(ComplexObject))] + [JsonSerializable(typeof(string[]))] + [JsonSerializable(typeof(JsonElement))] + partial class BuilderToolsJsonContext : JsonSerializerContext; +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index 93573693e..86aa16abb 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -1,8 +1,9 @@ -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Client; using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; using System.Text.Json; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Tests.Configuration; @@ -15,7 +16,7 @@ public McpServerScopedTests(ITestOutputHelper testOutputHelper) protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) { - mcpServerBuilder.WithTools(serializerOptions: JsonContext.Default.Options); + mcpServerBuilder.WithTools(serializerOptions: McpServerScopedTestsJsonContext.Default.Options); services.AddScoped(_ => new ComplexObject() { Name = "Scoped" }); } @@ -24,11 +25,20 @@ public async Task InjectScopedServiceAsArgument() { IMcpClient client = await CreateMcpClientForServer(); - var tools = await client.ListToolsAsync(JsonContext.Default.Options, TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); var tool = tools.First(t => t.Name == nameof(EchoTool.EchoComplex)); - Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.JsonSchema, AIJsonUtilities.DefaultOptions)); + Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.JsonSchema, McpJsonUtilities.DefaultOptions)); - Assert.Contains("\"Scoped\"", JsonSerializer.Serialize(await tool.InvokeAsync(cancellationToken: TestContext.Current.CancellationToken), AIJsonUtilities.DefaultOptions)); + int startingConstructed = ComplexObject.Constructed; + int startingDisposed = ComplexObject.Disposed; + + for (int i = 1; i <= 10; i++) + { + Assert.Contains("\"Scoped\"", JsonSerializer.Serialize(await tool.InvokeAsync(cancellationToken: TestContext.Current.CancellationToken), McpJsonUtilities.DefaultOptions)); + + Assert.Equal(startingConstructed + i, ComplexObject.Constructed); + Assert.Equal(startingDisposed + i, ComplexObject.Disposed); + } } [McpServerToolType] @@ -37,4 +47,29 @@ public sealed class EchoTool() [McpServerTool] public static string EchoComplex(ComplexObject complex) => complex.Name!; } + + public class ComplexObject : IAsyncDisposable + { + public static int Constructed; + public static int Disposed; + + public ComplexObject() + { + Interlocked.Increment(ref Constructed); + } + + public ValueTask DisposeAsync() + { + Interlocked.Increment(ref Disposed); + return default; + } + + public string? Name { get; set; } + public int Age { get; set; } + } + + [JsonSerializable(typeof(string))] + [JsonSerializable(typeof(ComplexObject))] + [JsonSerializable(typeof(JsonElement))] + partial class McpServerScopedTestsJsonContext : JsonSerializerContext; } diff --git a/tests/ModelContextProtocol.Tests/JsonContext.cs b/tests/ModelContextProtocol.Tests/JsonContext.cs deleted file mode 100644 index 4714380bf..000000000 --- a/tests/ModelContextProtocol.Tests/JsonContext.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System.Text.Json; -using System.Text.Json.Serialization; - -public class ComplexObject -{ - public string? Name { get; set; } - public int Age { get; set; } -} - -[JsonSerializable(typeof(bool))] -[JsonSerializable(typeof(int))] -[JsonSerializable(typeof(long))] -[JsonSerializable(typeof(double))] -[JsonSerializable(typeof(string))] -[JsonSerializable(typeof(DateTime))] -[JsonSerializable(typeof(DateTimeOffset))] -[JsonSerializable(typeof(ComplexObject))] -[JsonSerializable(typeof(string[]))] -[JsonSerializable(typeof(JsonElement))] -partial class JsonContext : JsonSerializerContext; \ No newline at end of file