From 926aacdb4bcd1144bbdaa3c2df2364e63c4c728b Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Thu, 22 May 2025 23:07:45 -0500 Subject: [PATCH 1/4] Log tool errors --- .../Server/AIFunctionMcpServerTool.cs | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 366eb23cd..94f46a50f 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,5 +1,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; @@ -11,6 +13,8 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed class AIFunctionMcpServerTool : McpServerTool { + private readonly ILogger _logger; + /// /// Creates an instance for a method, specified via a instance. /// @@ -19,7 +23,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool McpServerToolCreateOptions? options) { Throw.IfNull(method); - + options = DeriveOptions(method.Method, options); return Create(method.Method, method.Target, options); @@ -172,7 +176,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { Name = options?.Name ?? function.Name, Description = options?.Description ?? function.Description, - InputSchema = function.JsonSchema, + InputSchema = function.JsonSchema, }; if (options is not null) @@ -194,7 +198,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool); + return new AIFunctionMcpServerTool(function, tool, options?.Services); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -239,10 +243,11 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider) { AIFunction = function; ProtocolTool = tool; + _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; } /// @@ -277,6 +282,9 @@ public override async ValueTask InvokeAsync( } catch (Exception e) when (e is not OperationCanceledException) { + _logger.LogError(e, "Error invoking AIFunction tool '{ToolName}' with arguments '{Args}'.", + request.Params?.Name, string.Join(",", request.Params?.Arguments?.Keys ?? Array.Empty())); + string errorMessage = e is McpException ? $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : $"An error occurred invoking '{request.Params?.Name}'."; @@ -300,29 +308,29 @@ public override async ValueTask InvokeAsync( { Content = [] }, - + string text => new() { Content = [new() { Text = text, Type = "text" }] }, - + Content content => new() { Content = [content] }, - + IEnumerable texts => new() { Content = [.. texts.Select(x => new Content() { Type = "text", Text = x ?? string.Empty })] }, - + IEnumerable contentItems => ConvertAIContentEnumerableToCallToolResponse(contentItems), - + IEnumerable contents => new() { Content = [.. contents] }, - + CallToolResponse callToolResponse => callToolResponse, _ => new() From a87c47cb18032c56b28d6698759fba6bc5a9bbfc Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 23 May 2025 15:32:39 -0500 Subject: [PATCH 2/4] add test and implement source generated logging --- .../Server/AIFunctionMcpServerTool.cs | 226 +++++++++--------- .../Server/McpServerToolTests.cs | 41 ++++ 2 files changed, 155 insertions(+), 112 deletions(-) diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 94f46a50f..e6137d8d7 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -11,7 +11,7 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . -internal sealed class AIFunctionMcpServerTool : McpServerTool +internal sealed partial class AIFunctionMcpServerTool : McpServerTool { private readonly ILogger _logger; @@ -19,8 +19,8 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - Delegate method, - McpServerToolCreateOptions? options) + Delegate method, + McpServerToolCreateOptions? options) { Throw.IfNull(method); @@ -33,26 +33,26 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - MethodInfo method, - object? target, - McpServerToolCreateOptions? options) + MethodInfo method, + object? target, + McpServerToolCreateOptions? options) { Throw.IfNull(method); options = DeriveOptions(method, options); return Create( - AIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)), - options); + AIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)), + options); } /// /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - MethodInfo method, - Func, object> createTargetFunc, - McpServerToolCreateOptions? options) + MethodInfo method, + Func, object> createTargetFunc, + McpServerToolCreateOptions? options) { Throw.IfNull(method); Throw.IfNull(createTargetFunc); @@ -60,112 +60,112 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool options = DeriveOptions(method, options); return Create( - AIFunctionFactory.Create(method, args => - { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); - }, CreateAIFunctionFactoryOptions(method, options)), - options); + AIFunctionFactory.Create(method, args => + { + var request = (RequestContext)args.Context![typeof(RequestContext)]!; + return createTargetFunc(request); + }, CreateAIFunctionFactoryOptions(method, options)), + options); } // TODO: Fix the need for this suppression. [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2111:ReflectionToDynamicallyAccessedMembers", - Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")] + Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")] internal static Func GetCreateInstanceFunc() => - static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ? - ActivatorUtilities.CreateInstance(services, type) : - Activator.CreateInstance(type)!; + static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ? + ActivatorUtilities.CreateInstance(services, type) : + Activator.CreateInstance(type)!; private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( - MethodInfo method, McpServerToolCreateOptions? options) => - new() + MethodInfo method, McpServerToolCreateOptions? options) => + new() + { + Name = options?.Name ?? method.GetCustomAttribute()?.Name, + Description = options?.Description, + MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), + SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, + ConfigureParameterBinding = pi => + { + if (pi.ParameterType == typeof(RequestContext)) { - Name = options?.Name ?? method.GetCustomAttribute()?.Name, - Description = options?.Description, - MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), - SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, - ConfigureParameterBinding = pi => + return new() { - if (pi.ParameterType == typeof(RequestContext)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args), + }; + } - if (pi.ParameterType == typeof(IMcpServer)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } + if (pi.ParameterType == typeof(IMcpServer)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args)?.Server, + }; + } - if (pi.ParameterType == typeof(IProgress)) + if (pi.ParameterType == typeof(IProgress)) + { + // Bind IProgress to the progress token in the request, + // if there is one. If we can't get one, return a nop progress. + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + { + var requestContent = GetRequestContext(args); + if (requestContent?.Server is { } server && + requestContent?.Params?.Meta?.ProgressToken is { } progressToken) { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.Meta?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; + return new TokenProgress(server, progressToken); } - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? - (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), - }; - } + return NullProgress.Instance; + }, + }; + } - if (pi.GetCustomAttribute() is { } keyedAttr) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? - (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), - }; - } + if (options?.Services is { } services && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } - return default; + return default; - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && - orc is RequestContext requestContext) - { - return requestContext; - } + static RequestContext? GetRequestContext(AIFunctionArguments args) + { + if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && + orc is RequestContext requestContext) + { + return requestContext; + } - return null; - } - }, - JsonSchemaCreateOptions = options?.SchemaCreateOptions, - }; + return null; + } + }, + JsonSchemaCreateOptions = options?.SchemaCreateOptions, + }; /// Creates an that wraps the specified . public static new AIFunctionMcpServerTool Create(AIFunction function, McpServerToolCreateOptions? options) @@ -182,10 +182,10 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( if (options is not null) { if (options.Title is not null || - options.Idempotent is not null || - options.Destructive is not null || - options.OpenWorld is not null || - options.ReadOnly is not null) + options.Idempotent is not null || + options.Destructive is not null || + options.OpenWorld is not null || + options.ReadOnly is not null) { tool.Annotations = new() { @@ -255,7 +255,7 @@ private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider /// public override async ValueTask InvokeAsync( - RequestContext request, CancellationToken cancellationToken = default) + RequestContext request, CancellationToken cancellationToken = default) { Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); @@ -282,12 +282,11 @@ public override async ValueTask InvokeAsync( } catch (Exception e) when (e is not OperationCanceledException) { - _logger.LogError(e, "Error invoking AIFunction tool '{ToolName}' with arguments '{Args}'.", - request.Params?.Name, string.Join(",", request.Params?.Arguments?.Keys ?? Array.Empty())); + ToolCallError(request.Params?.Name ?? string.Empty, e); string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; return new() { @@ -336,10 +335,10 @@ public override async ValueTask InvokeAsync( _ => new() { Content = [new() - { - Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))), - Type = "text" - }] +{ +Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))), +Type = "text" +}] }, }; } @@ -367,4 +366,7 @@ private static CallToolResponse ConvertAIContentEnumerableToCallToolResponse(IEn IsError = allErrorContent && hasAny }; } + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index cb98d9bce..db6f1cde4 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,7 +1,9 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; using System.Text.Json; @@ -381,6 +383,45 @@ public async Task SupportsSchemaCreateOptions() ); } + [Fact] + public async Task ToolCallError_LogsErrorMessage() + { + // Arrange + var mockLoggerProvider = new MockLoggerProvider(); + var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider }); + var services = new ServiceCollection(); + services.AddSingleton(loggerFactory); + var serviceProvider = services.BuildServiceProvider(); + + var toolName = "tool-that-throws"; + var exceptionMessage = "Test exception message"; + + McpServerTool tool = McpServerTool.Create(() => + { + throw new InvalidOperationException(exceptionMessage); + }, new() { Name = toolName, Services = serviceProvider }); + + var mockServer = new Mock(); + var request = new RequestContext(mockServer.Object) + { + Params = new CallToolRequestParams() { Name = toolName }, + Services = serviceProvider + }; + + // Act + var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken); + + // Assert + Assert.True(result.IsError); + Assert.Single(result.Content); + Assert.Equal($"An error occurred invoking '{toolName}'.", result.Content[0].Text); + + var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); + Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message); + Assert.IsType(errorLog.Exception); + Assert.Equal(exceptionMessage, errorLog.Exception.Message); + } + private sealed class MyService; private class DisposableToolType : IDisposable From 4e3f411834be7ada7663b6ad722196d91de9bd13 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 23 May 2025 16:18:58 -0500 Subject: [PATCH 3/4] whitespace cleanup --- .../Server/AIFunctionMcpServerTool.cs | 234 +++++++++--------- 1 file changed, 117 insertions(+), 117 deletions(-) diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index e6137d8d7..6b7590502 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -19,11 +19,11 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - Delegate method, - McpServerToolCreateOptions? options) + Delegate method, + McpServerToolCreateOptions? options) { Throw.IfNull(method); - + options = DeriveOptions(method.Method, options); return Create(method.Method, method.Target, options); @@ -33,26 +33,26 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - MethodInfo method, - object? target, - McpServerToolCreateOptions? options) + MethodInfo method, + object? target, + McpServerToolCreateOptions? options) { Throw.IfNull(method); options = DeriveOptions(method, options); return Create( - AIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)), - options); + AIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)), + options); } /// /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - MethodInfo method, - Func, object> createTargetFunc, - McpServerToolCreateOptions? options) + MethodInfo method, + Func, object> createTargetFunc, + McpServerToolCreateOptions? options) { Throw.IfNull(method); Throw.IfNull(createTargetFunc); @@ -60,112 +60,112 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool options = DeriveOptions(method, options); return Create( - AIFunctionFactory.Create(method, args => - { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); - }, CreateAIFunctionFactoryOptions(method, options)), - options); + AIFunctionFactory.Create(method, args => + { + var request = (RequestContext)args.Context![typeof(RequestContext)]!; + return createTargetFunc(request); + }, CreateAIFunctionFactoryOptions(method, options)), + options); } // TODO: Fix the need for this suppression. [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2111:ReflectionToDynamicallyAccessedMembers", - Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")] + Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")] internal static Func GetCreateInstanceFunc() => - static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ? - ActivatorUtilities.CreateInstance(services, type) : - Activator.CreateInstance(type)!; + static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ? + ActivatorUtilities.CreateInstance(services, type) : + Activator.CreateInstance(type)!; private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( - MethodInfo method, McpServerToolCreateOptions? options) => - new() - { - Name = options?.Name ?? method.GetCustomAttribute()?.Name, - Description = options?.Description, - MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), - SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, - ConfigureParameterBinding = pi => - { - if (pi.ParameterType == typeof(RequestContext)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) + MethodInfo method, McpServerToolCreateOptions? options) => + new() { - return new() + Name = options?.Name ?? method.GetCustomAttribute()?.Name, + Description = options?.Description, + MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), + SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, + ConfigureParameterBinding = pi => { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } + if (pi.ParameterType == typeof(RequestContext)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args), + }; + } - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.Meta?.ProgressToken is { } progressToken) + if (pi.ParameterType == typeof(IMcpServer)) { - return new TokenProgress(server, progressToken); + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args)?.Server, + }; } - return NullProgress.Instance; - }, - }; - } + if (pi.ParameterType == typeof(IProgress)) + { + // Bind IProgress to the progress token in the request, + // if there is one. If we can't get one, return a nop progress. + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + { + var requestContent = GetRequestContext(args); + if (requestContent?.Server is { } server && + requestContent?.Params?.Meta?.ProgressToken is { } progressToken) + { + return new TokenProgress(server, progressToken); + } + + return NullProgress.Instance; + }, + }; + } - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? - (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), - }; - } + if (options?.Services is { } services && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } - if (pi.GetCustomAttribute() is { } keyedAttr) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? - (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), - }; - } + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } - return default; + return default; - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && - orc is RequestContext requestContext) - { - return requestContext; - } + static RequestContext? GetRequestContext(AIFunctionArguments args) + { + if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && + orc is RequestContext requestContext) + { + return requestContext; + } - return null; - } - }, - JsonSchemaCreateOptions = options?.SchemaCreateOptions, - }; + return null; + } + }, + JsonSchemaCreateOptions = options?.SchemaCreateOptions, + }; /// Creates an that wraps the specified . public static new AIFunctionMcpServerTool Create(AIFunction function, McpServerToolCreateOptions? options) @@ -176,16 +176,16 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { Name = options?.Name ?? function.Name, Description = options?.Description ?? function.Description, - InputSchema = function.JsonSchema, + InputSchema = function.JsonSchema, }; if (options is not null) { if (options.Title is not null || - options.Idempotent is not null || - options.Destructive is not null || - options.OpenWorld is not null || - options.ReadOnly is not null) + options.Idempotent is not null || + options.Destructive is not null || + options.OpenWorld is not null || + options.ReadOnly is not null) { tool.Annotations = new() { @@ -255,7 +255,7 @@ private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider /// public override async ValueTask InvokeAsync( - RequestContext request, CancellationToken cancellationToken = default) + RequestContext request, CancellationToken cancellationToken = default) { Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); @@ -285,8 +285,8 @@ public override async ValueTask InvokeAsync( ToolCallError(request.Params?.Name ?? string.Empty, e); string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; return new() { @@ -307,38 +307,38 @@ public override async ValueTask InvokeAsync( { Content = [] }, - + string text => new() { Content = [new() { Text = text, Type = "text" }] }, - + Content content => new() { Content = [content] }, - + IEnumerable texts => new() { Content = [.. texts.Select(x => new Content() { Type = "text", Text = x ?? string.Empty })] }, - + IEnumerable contentItems => ConvertAIContentEnumerableToCallToolResponse(contentItems), - + IEnumerable contents => new() { Content = [.. contents] }, - + CallToolResponse callToolResponse => callToolResponse, _ => new() { Content = [new() -{ -Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))), -Type = "text" -}] + { + Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))), + Type = "text" + }] }, }; } From b39aaf512bc247120838be59f534910d97ff8e36 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 23 May 2025 14:29:35 -0700 Subject: [PATCH 4/4] Update src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs --- src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 6b7590502..7f91186b1 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -247,7 +247,7 @@ private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider { AIFunction = function; ProtocolTool = tool; - _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; + _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; } ///