From 9d2ce788fa19e3ec1bf59f42246fb3857b62b2dc Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sun, 29 Jun 2025 21:35:10 -0400 Subject: [PATCH] Enable injecting IMcpServer and friends into ctors --- .../Server/AIFunctionMcpServerPrompt.cs | 75 +++--------------- .../Server/AIFunctionMcpServerResource.cs | 72 +++--------------- .../Server/AIFunctionMcpServerTool.cs | 76 +++---------------- .../Server/AugmentedServiceProvider.cs | 58 ++++++++++++++ .../Server/McpServerPromptTests.cs | 54 +++++++++++++ .../Server/McpServerResourceTests.cs | 53 +++++++++++++ .../Server/McpServerToolTests.cs | 52 +++++++++++++ 7 files changed, 249 insertions(+), 191 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs index 8d446c58c..ef463c374 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using System.ComponentModel; +using System.Diagnostics; using System.Reflection; using System.Text.Json; @@ -57,8 +58,8 @@ internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt return Create( AIFunctionFactory.Create(method, args => { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); + Debug.Assert(args.Services is RequestServiceProvider, $"The service provider should be a {nameof(RequestServiceProvider)} for this method to work correctly."); + return createTargetFunc(((RequestServiceProvider)args.Services!).Request); }, CreateAIFunctionFactoryOptions(method, options)), options); } @@ -74,54 +75,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( JsonSchemaCreateOptions = options?.SchemaCreateOptions, ConfigureParameterBinding = pi => { - if (pi.ParameterType == typeof(RequestContext)) + if (RequestServiceProvider.IsAugmentedWith(pi.ParameterType) || + (options?.Services?.GetService() is { } ispis && + ispis.IsService(pi.ParameterType))) { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } - - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. return new() { ExcludeFromSchema = true, BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; - } - - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + args.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -133,24 +95,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; } return default; - - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && - orc is RequestContext requestContext) - { - return requestContext; - } - - return null; - } }, }; @@ -226,14 +177,10 @@ public override async ValueTask GetAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - AIFunctionArguments arguments = new() - { - Services = request.Services, - Context = new Dictionary() { [typeof(RequestContext)] = request } - }; + request.Services = new RequestServiceProvider(request, request.Services); + AIFunctionArguments arguments = new() { Services = request.Services }; - var argDict = request.Params?.Arguments; - if (argDict is not null) + if (request.Params?.Arguments is { } argDict) { foreach (var kvp in argDict) { diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs index c44339b14..5412c4aca 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Protocol; using System.Collections.Concurrent; using System.ComponentModel; +using System.Diagnostics; using System.Globalization; using System.Reflection; using System.Text; @@ -64,8 +65,8 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource return Create( AIFunctionFactory.Create(method, args => { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); + Debug.Assert(args.Services is RequestServiceProvider, $"The service provider should be a {nameof(RequestServiceProvider)} for this method to work correctly."); + return createTargetFunc(((RequestServiceProvider)args.Services!).Request); }, CreateAIFunctionFactoryOptions(method, options)), options); } @@ -81,54 +82,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( JsonSchemaCreateOptions = options?.SchemaCreateOptions, ConfigureParameterBinding = pi => { - if (pi.ParameterType == typeof(RequestContext)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } - - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; - } - - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) + if (RequestServiceProvider.IsAugmentedWith(pi.ParameterType) || + (options?.Services?.GetService() is { } ispis && + ispis.IsService(pi.ParameterType))) { return new() { ExcludeFromSchema = true, BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + args.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -140,7 +102,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -172,17 +134,6 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( } return default; - - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var rc) is true && - rc is RequestContext requestContext) - { - return requestContext; - } - - return null; - } }, }; @@ -365,11 +316,8 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour } // Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI. - AIFunctionArguments arguments = new() - { - Services = request.Services, - Context = new Dictionary() { [typeof(RequestContext)] = request } - }; + request.Services = new RequestServiceProvider(request, request.Services); + AIFunctionArguments arguments = new() { Services = request.Services }; // For templates, populate the arguments from the URI template. if (match is not null) diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 725aaf3ec..596871f76 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -4,7 +4,7 @@ using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics; using System.Reflection; using System.Text.Json; using System.Text.Json.Nodes; @@ -64,8 +64,8 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool return Create( AIFunctionFactory.Create(method, args => { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); + Debug.Assert(args.Services is RequestServiceProvider, $"The service provider should be a {nameof(RequestServiceProvider)} for this method to work correctly."); + return createTargetFunc(((RequestServiceProvider)args.Services!).Request); }, CreateAIFunctionFactoryOptions(method, options)), options); } @@ -81,54 +81,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( JsonSchemaCreateOptions = options?.SchemaCreateOptions, ConfigureParameterBinding = pi => { - if (pi.ParameterType == typeof(RequestContext)) + if (RequestServiceProvider.IsAugmentedWith(pi.ParameterType) || + (options?.Services?.GetService() is { } ispis && + ispis.IsService(pi.ParameterType))) { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } - - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. return new() { ExcludeFromSchema = true, BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; - } - - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + args.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -140,24 +101,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; } return default; - - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && - orc is RequestContext requestContext) - { - return requestContext; - } - - return null; - } }, }; @@ -260,14 +210,10 @@ public override async ValueTask InvokeAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - AIFunctionArguments arguments = new() - { - Services = request.Services, - Context = new Dictionary() { [typeof(RequestContext)] = request } - }; + request.Services = new RequestServiceProvider(request, request.Services); + AIFunctionArguments arguments = new() { Services = request.Services }; - var argDict = request.Params?.Arguments; - if (argDict is not null) + if (request.Params?.Arguments is { } argDict) { foreach (var kvp in argDict) { diff --git a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs b/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs new file mode 100644 index 000000000..3372072fe --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs @@ -0,0 +1,58 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// Augments a service provider with additional request-related services. +internal sealed class RequestServiceProvider( + RequestContext request, IServiceProvider? innerServices) : + IServiceProvider, IKeyedServiceProvider, + IServiceProviderIsService, IServiceProviderIsKeyedService, + IDisposable, IAsyncDisposable + where TRequestParams : RequestParams +{ + /// Gets the request associated with this instance. + public RequestContext Request => request; + + /// Gets whether the specified type is in the list of additional types this service provider wraps around the one in a provided request's services. + public static bool IsAugmentedWith(Type serviceType) => + serviceType == typeof(RequestContext) || + serviceType == typeof(IMcpServer) || + serviceType == typeof(IProgress); + + /// + public object? GetService(Type serviceType) => + serviceType == typeof(RequestContext) ? request : + serviceType == typeof(IMcpServer) ? request.Server : + serviceType == typeof(IProgress) ? + (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : + innerServices?.GetService(serviceType); + + /// + public bool IsService(Type serviceType) => + IsAugmentedWith(serviceType) || + (innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; + + /// + public bool IsKeyedService(Type serviceType, object? serviceKey) => + (serviceKey is null && IsService(serviceType)) || + (innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; + + /// + public object? GetKeyedService(Type serviceType, object? serviceKey) => + serviceKey is null ? GetService(serviceType) : + (innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); + + /// + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => + GetKeyedService(serviceType, serviceKey) ?? + throw new InvalidOperationException($"No service of type '{serviceType}' with key '{serviceKey}' is registered."); + + /// + public void Dispose() => + (innerServices as IDisposable)?.Dispose(); + + /// + public ValueTask DisposeAsync() => + innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index ca1bfe97b..90998e24b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -1,9 +1,11 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Primitives; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; using System.ComponentModel; +using System.Diagnostics; using System.Reflection; using System.Text.Json; using System.Text.Json.Nodes; @@ -44,6 +46,58 @@ public async Task SupportsIMcpServer() Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } + [Fact] + public async Task SupportsCtorInjection() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + Mock mockServer = new(); + mockServer.SetupGet(s => s.Services).Returns(services); + + MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestPrompt)); + Assert.NotNull(testMethod); + McpServerPrompt prompt = McpServerPrompt.Create(testMethod, r => + { + Assert.NotNull(r.Services); + return ActivatorUtilities.CreateInstance(r.Services, typeof(HasCtorWithSpecialParameters)); + }, new() { Services = services }); + + var result = await prompt.GetAsync( + new RequestContext(mockServer.Object), + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Messages); + Assert.Single(result.Messages); + Assert.Equal("True True True True", Assert.IsType(result.Messages[0].Content).Text); + } + + private sealed class HasCtorWithSpecialParameters + { + private readonly MyService _ms; + private readonly IMcpServer _server; + private readonly RequestContext _request; + private readonly IProgress _progress; + + public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + { + Assert.NotNull(ms); + Assert.NotNull(server); + Assert.NotNull(request); + Assert.NotNull(progress); + + _ms = ms; + _server = server; + _request = request; + _progress = progress; + } + + public string TestPrompt() => $"{_ms is not null} {_server is not null} {_request is not null} {_progress is not null}"; + } + [Fact] public async Task SupportsServiceFromDI() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 94a860201..fb0772d04 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -360,6 +360,59 @@ public async Task SupportsIMcpServer() Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); } + [Fact] + public async Task SupportsCtorInjection() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + Mock mockServer = new(); + mockServer.SetupGet(s => s.Services).Returns(services); + + MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestResource)); + Assert.NotNull(testMethod); + McpServerResource tool = McpServerResource.Create(testMethod, r => + { + Assert.NotNull(r.Services); + return ActivatorUtilities.CreateInstance(r.Services, typeof(HasCtorWithSpecialParameters)); + }, new() { Services = services }); + + var result = await tool.ReadAsync( + new RequestContext(mockServer.Object) { Params = new() { Uri = "https://something" } }, + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Contents); + Assert.Single(result.Contents); + Assert.Equal("True True True True", Assert.IsType(result.Contents[0]).Text); + } + + private sealed class HasCtorWithSpecialParameters + { + private readonly MyService _ms; + private readonly IMcpServer _server; + private readonly RequestContext _request; + private readonly IProgress _progress; + + public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + { + Assert.NotNull(ms); + Assert.NotNull(server); + Assert.NotNull(request); + Assert.NotNull(progress); + + _ms = ms; + _server = server; + _request = request; + _progress = progress; + } + + [McpServerResource(UriTemplate = "https://something")] + public string TestResource() => $"{_ms is not null} {_server is not null} {_request is not null} {_progress is not null}"; + } + [Theory] [InlineData(ServiceLifetime.Singleton)] [InlineData(ServiceLifetime.Scoped)] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 742133413..0f67f2a58 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -50,6 +50,58 @@ public async Task SupportsIMcpServer() Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } + [Fact] + public async Task SupportsCtorInjection() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + Mock mockServer = new(); + mockServer.SetupGet(s => s.Services).Returns(services); + + MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestTool)); + Assert.NotNull(testMethod); + McpServerTool tool = McpServerTool.Create(testMethod, r => + { + Assert.NotNull(r.Services); + return ActivatorUtilities.CreateInstance(r.Services, typeof(HasCtorWithSpecialParameters)); + }, new() { Services = services }); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object), + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.Single(result.Content); + Assert.Equal("True True True True", Assert.IsType(result.Content[0]).Text); + } + + private sealed class HasCtorWithSpecialParameters + { + private readonly MyService _ms; + private readonly IMcpServer _server; + private readonly RequestContext _request; + private readonly IProgress _progress; + + public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + { + Assert.NotNull(ms); + Assert.NotNull(server); + Assert.NotNull(request); + Assert.NotNull(progress); + + _ms = ms; + _server = server; + _request = request; + _progress = progress; + } + + public string TestTool() => $"{_ms is not null} {_server is not null} {_request is not null} {_progress is not null}"; + } + [Theory] [InlineData(ServiceLifetime.Singleton)] [InlineData(ServiceLifetime.Scoped)]