diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index fcd7980d9..d407f4265 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -442,7 +442,8 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc // Now that the request has been sent, register for cancellation. If we registered before, // a cancellation request could arrive before the server knew about that request ID, in which // case the server could ignore it. - LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); + string? target = GetRequestTarget(request); + LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id, toolName: target); JsonRpcMessage? response; using (var registration = RegisterCancellation(cancellationToken, request)) { @@ -451,7 +452,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc if (response is JsonRpcError error) { - LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); + LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code, toolName: target); throw new McpProtocolException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); } @@ -464,11 +465,11 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc if (_logger.IsEnabled(LogLevel.Trace)) { - LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); + LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null", toolName: target); } else { - LogRequestResponseReceived(EndpointName, request.Method); + LogRequestResponseReceived(EndpointName, request.Method, toolName: target); } return success; @@ -769,6 +770,29 @@ private static TimeSpan GetElapsed(long startingTimestamp) => return null; } + /// + /// Extracts the target identifier (tool name, prompt name, or resource URI) from a request. + /// + /// The JSON-RPC request. + /// The target identifier if available; otherwise, null. + private static string? GetRequestTarget(JsonRpcRequest request) + { + if (request.Params is not JsonObject paramsObj) + { + return null; + } + + return request.Method switch + { + RequestMethods.ToolsCall => GetStringProperty(paramsObj, "name"), + RequestMethods.PromptsGet => GetStringProperty(paramsObj, "name"), + RequestMethods.ResourcesRead => GetStringProperty(paramsObj, "uri"), + RequestMethods.ResourcesSubscribe => GetStringProperty(paramsObj, "uri"), + RequestMethods.ResourcesUnsubscribe => GetStringProperty(paramsObj, "uri"), + _ => null + }; + } + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] private partial void LogEndpointMessageProcessingCanceled(string endpointName); @@ -784,8 +808,8 @@ private static TimeSpan GetElapsed(long startingTimestamp) => [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] - private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}' (tool: '{ToolName}'): {ErrorMessage} ({ErrorCode}).")] + private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode, string? toolName = null); [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); @@ -799,11 +823,11 @@ private static TimeSpan GetElapsed(long startingTimestamp) => [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] - private partial void LogRequestResponseReceived(string endpointName, string method); + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method} (tool: '{ToolName}')")] + private partial void LogRequestResponseReceived(string endpointName, string method, string? toolName = null); - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] - private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method} (tool: '{ToolName}'). Response: '{Response}'.")] + private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response, string? toolName = null); [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] private partial void LogMessageRead(string endpointName, string messageType); @@ -820,8 +844,8 @@ private static TimeSpan GetElapsed(long startingTimestamp) => [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] private partial void LogNoHandlerFoundForRequest(string endpointName, string method); - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] - private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}' (tool: '{ToolName}').")] + private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId, string? toolName = null); [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] private partial void LogSendingMessage(string endpointName); diff --git a/tests/Common/Utils/MockLoggerProvider.cs b/tests/Common/Utils/MockLoggerProvider.cs index 14a0f401a..a56c71036 100644 --- a/tests/Common/Utils/MockLoggerProvider.cs +++ b/tests/Common/Utils/MockLoggerProvider.cs @@ -6,6 +6,7 @@ namespace ModelContextProtocol.Tests.Utils; public class MockLoggerProvider() : ILoggerProvider { public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> LogMessages { get; } = []; + public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception, object? State)> LogMessagesWithState { get; } = []; public ILogger CreateLogger(string categoryName) { @@ -22,6 +23,7 @@ public void Log( LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { mockProvider.LogMessages.Enqueue((category, logLevel, eventId, formatter(state, exception), exception)); + mockProvider.LogMessagesWithState.Enqueue((category, logLevel, eventId, formatter(state, exception), exception, state)); } public bool IsEnabled(LogLevel logLevel) => true; diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 79bf6fe50..88c74e9b1 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -20,13 +20,20 @@ namespace ModelContextProtocol.Tests.Configuration; public partial class McpServerBuilderExtensionsToolsTests : ClientServerTestBase { + private MockLoggerProvider _mockLoggerProvider = new(); + public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { + // Configure LoggerFactory to use Debug level and add MockLoggerProvider + LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(XunitLoggerProvider); + builder.AddProvider(_mockLoggerProvider); + builder.SetMinimumLevel(LogLevel.Debug); + }); } - private MockLoggerProvider _mockLoggerProvider = new(); - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) { mcpServerBuilder @@ -733,6 +740,86 @@ await client.SendNotificationAsync( await Assert.ThrowsAnyAsync(async () => await invokeTask); } + [Fact] + public async Task ToolName_Captured_In_Structured_Logging() + { + await using McpClient client = await CreateMcpClientForServer(); + + // Call a tool that will succeed + var result = await client.CallToolAsync( + "echo", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(result); + + // Verify that the tool name is captured in structured logging + // The LogMessagesWithState should contain log entries with tool name in the state + var relevantLogs = _mockLoggerProvider.LogMessagesWithState + .Where(m => m.Category == "ModelContextProtocol.Client.McpClient" && + m.Message.Contains("tools/call")) + .ToList(); + + Assert.NotEmpty(relevantLogs); + + // Check that at least one log entry has the tool name in its structured state + // This demonstrates how users can extract the tool name from TState in a custom ILoggerProvider + // The State object is IReadOnlyList> which contains + // structured logging parameters like "ToolName", "Method", "EndpointName", etc. + bool foundToolName = relevantLogs.Any(log => + { + if (log.State is IReadOnlyList> stateList) + { + return stateList.Any(kvp => + kvp.Key == "ToolName" && + kvp.Value?.ToString() == "echo"); + } + return false; + }); + + Assert.True(foundToolName, "Tool name 'echo' was not found in structured logging state"); + } + + [Fact] + public async Task ToolName_Captured_In_Structured_Logging_OnToolError() + { + await using McpClient client = await CreateMcpClientForServer(); + + // Call a tool that will error - note that tool errors are returned as CallToolResult with IsError=true, + // not thrown as exceptions per the MCP spec + var result = await client.CallToolAsync( + "throw_exception", + cancellationToken: TestContext.Current.CancellationToken); + + // Verify the tool error was returned properly + Assert.NotNull(result); + Assert.True(result.IsError); + + // Verify that the tool name is captured in structured logging + // even when the tool encounters an error + var relevantLogs = _mockLoggerProvider.LogMessagesWithState + .Where(m => m.LogLevel == LogLevel.Debug && + (m.Message.Contains("waiting for response") || m.Message.Contains("response received")) && + m.Message.Contains("tools/call")) + .ToList(); + + Assert.NotEmpty(relevantLogs); + + // Check that at least one log entry has the tool name in its structured state + bool foundToolName = relevantLogs.Any(log => + { + if (log.State is IReadOnlyList> stateList) + { + return stateList.Any(kvp => + kvp.Key == "ToolName" && + kvp.Value?.ToString() == "throw_exception"); + } + return false; + }); + + Assert.True(foundToolName, "Tool name 'throw_exception' was not found in structured logging state"); + } + [McpServerToolType] public sealed class EchoTool(ObjectWithId objectFromDI) {