diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 366eb23cd..7f91186b1 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,5 +1,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; @@ -9,8 +11,10 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . -internal sealed class AIFunctionMcpServerTool : McpServerTool +internal sealed partial class AIFunctionMcpServerTool : McpServerTool { + private readonly ILogger _logger; + /// /// Creates an instance for a method, specified via a instance. /// @@ -194,7 +198,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool); + return new AIFunctionMcpServerTool(function, tool, options?.Services); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -239,10 +243,11 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider) { AIFunction = function; ProtocolTool = tool; + _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; } /// @@ -277,6 +282,8 @@ public override async ValueTask InvokeAsync( } catch (Exception e) when (e is not OperationCanceledException) { + ToolCallError(request.Params?.Name ?? string.Empty, e); + string errorMessage = e is McpException ? $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : $"An error occurred invoking '{request.Params?.Name}'."; @@ -359,4 +366,7 @@ private static CallToolResponse ConvertAIContentEnumerableToCallToolResponse(IEn IsError = allErrorContent && hasAny }; } + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index cb98d9bce..db6f1cde4 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,7 +1,9 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; using System.Text.Json; @@ -381,6 +383,45 @@ public async Task SupportsSchemaCreateOptions() ); } + [Fact] + public async Task ToolCallError_LogsErrorMessage() + { + // Arrange + var mockLoggerProvider = new MockLoggerProvider(); + var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider }); + var services = new ServiceCollection(); + services.AddSingleton(loggerFactory); + var serviceProvider = services.BuildServiceProvider(); + + var toolName = "tool-that-throws"; + var exceptionMessage = "Test exception message"; + + McpServerTool tool = McpServerTool.Create(() => + { + throw new InvalidOperationException(exceptionMessage); + }, new() { Name = toolName, Services = serviceProvider }); + + var mockServer = new Mock(); + var request = new RequestContext(mockServer.Object) + { + Params = new CallToolRequestParams() { Name = toolName }, + Services = serviceProvider + }; + + // Act + var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken); + + // Assert + Assert.True(result.IsError); + Assert.Single(result.Content); + Assert.Equal($"An error occurred invoking '{toolName}'.", result.Content[0].Text); + + var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); + Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message); + Assert.IsType(errorLog.Exception); + Assert.Equal(exceptionMessage, errorLog.Exception.Message); + } + private sealed class MyService; private class DisposableToolType : IDisposable