Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions src/ModelContextProtocol.Core/McpSessionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
// Now that the request has been sent, register for cancellation. If we registered before,
// a cancellation request could arrive before the server knew about that request ID, in which
// case the server could ignore it.
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id);
string? target = GetRequestTarget(request);
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id, toolName: target);
JsonRpcMessage? response;
using (var registration = RegisterCancellation(cancellationToken, request))
{
Expand All @@ -451,7 +452,7 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc

if (response is JsonRpcError error)
{
LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code);
LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code, toolName: target);
throw new McpProtocolException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code);
}

Expand All @@ -464,11 +465,11 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc

if (_logger.IsEnabled(LogLevel.Trace))
{
LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null");
LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null", toolName: target);
}
else
{
LogRequestResponseReceived(EndpointName, request.Method);
LogRequestResponseReceived(EndpointName, request.Method, toolName: target);
}

return success;
Expand Down Expand Up @@ -769,6 +770,29 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
return null;
}

/// <summary>
/// Extracts the target identifier (tool name, prompt name, or resource URI) from a request.
/// </summary>
/// <param name="request">The JSON-RPC request.</param>
/// <returns>The target identifier if available; otherwise, null.</returns>
private static string? GetRequestTarget(JsonRpcRequest request)
{
if (request.Params is not JsonObject paramsObj)
{
return null;
}

return request.Method switch
{
RequestMethods.ToolsCall => GetStringProperty(paramsObj, "name"),
RequestMethods.PromptsGet => GetStringProperty(paramsObj, "name"),
RequestMethods.ResourcesRead => GetStringProperty(paramsObj, "uri"),
RequestMethods.ResourcesSubscribe => GetStringProperty(paramsObj, "uri"),
RequestMethods.ResourcesUnsubscribe => GetStringProperty(paramsObj, "uri"),
_ => null
};
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")]
private partial void LogEndpointMessageProcessingCanceled(string endpointName);

Expand All @@ -784,8 +808,8 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")]
private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId);

[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")]
private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode);
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}' (tool: '{ToolName}'): {ErrorMessage} ({ErrorCode}).")]
private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode, string? toolName = null);

[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")]
private partial void LogSendingRequestInvalidResponseType(string endpointName, string method);
Expand All @@ -799,11 +823,11 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")]
private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")]
private partial void LogRequestResponseReceived(string endpointName, string method);
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method} (tool: '{ToolName}')")]
private partial void LogRequestResponseReceived(string endpointName, string method, string? toolName = null);

[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")]
private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response);
[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method} (tool: '{ToolName}'). Response: '{Response}'.")]
private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response, string? toolName = null);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")]
private partial void LogMessageRead(string endpointName, string messageType);
Expand All @@ -820,8 +844,8 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")]
private partial void LogNoHandlerFoundForRequest(string endpointName, string method);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")]
private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId);
[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}' (tool: '{ToolName}').")]
private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId, string? toolName = null);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")]
private partial void LogSendingMessage(string endpointName);
Expand Down
2 changes: 2 additions & 0 deletions tests/Common/Utils/MockLoggerProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace ModelContextProtocol.Tests.Utils;
public class MockLoggerProvider() : ILoggerProvider
{
public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> LogMessages { get; } = [];
public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception, object? State)> LogMessagesWithState { get; } = [];

public ILogger CreateLogger(string categoryName)
{
Expand All @@ -22,6 +23,7 @@ public void Log<TState>(
LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
{
mockProvider.LogMessages.Enqueue((category, logLevel, eventId, formatter(state, exception), exception));
mockProvider.LogMessagesWithState.Enqueue((category, logLevel, eventId, formatter(state, exception), exception, state));
}

public bool IsEnabled(LogLevel logLevel) => true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ namespace ModelContextProtocol.Tests.Configuration;

public partial class McpServerBuilderExtensionsToolsTests : ClientServerTestBase
{
private MockLoggerProvider _mockLoggerProvider = new();

public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper)
: base(testOutputHelper)
{
// Configure LoggerFactory to use Debug level and add MockLoggerProvider
LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(XunitLoggerProvider);
builder.AddProvider(_mockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Debug);
});
}

private MockLoggerProvider _mockLoggerProvider = new();

protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder)
{
mcpServerBuilder
Expand Down Expand Up @@ -733,6 +740,86 @@ await client.SendNotificationAsync(
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await invokeTask);
}

[Fact]
public async Task ToolName_Captured_In_Structured_Logging()
{
await using McpClient client = await CreateMcpClientForServer();

// Call a tool that will succeed
var result = await client.CallToolAsync(
"echo",
new Dictionary<string, object?> { ["message"] = "test" },
cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);

// Verify that the tool name is captured in structured logging
// The LogMessagesWithState should contain log entries with tool name in the state
var relevantLogs = _mockLoggerProvider.LogMessagesWithState
.Where(m => m.Category == "ModelContextProtocol.Client.McpClient" &&
m.Message.Contains("tools/call"))
.ToList();

Assert.NotEmpty(relevantLogs);

// Check that at least one log entry has the tool name in its structured state
// This demonstrates how users can extract the tool name from TState in a custom ILoggerProvider
// The State object is IReadOnlyList<KeyValuePair<string, object?>> which contains
// structured logging parameters like "ToolName", "Method", "EndpointName", etc.
bool foundToolName = relevantLogs.Any(log =>
{
if (log.State is IReadOnlyList<KeyValuePair<string, object?>> stateList)
{
return stateList.Any(kvp =>
kvp.Key == "ToolName" &&
kvp.Value?.ToString() == "echo");
}
return false;
});

Assert.True(foundToolName, "Tool name 'echo' was not found in structured logging state");
}

[Fact]
public async Task ToolName_Captured_In_Structured_Logging_OnToolError()
{
await using McpClient client = await CreateMcpClientForServer();

// Call a tool that will error - note that tool errors are returned as CallToolResult with IsError=true,
// not thrown as exceptions per the MCP spec
var result = await client.CallToolAsync(
"throw_exception",
cancellationToken: TestContext.Current.CancellationToken);

// Verify the tool error was returned properly
Assert.NotNull(result);
Assert.True(result.IsError);

// Verify that the tool name is captured in structured logging
// even when the tool encounters an error
var relevantLogs = _mockLoggerProvider.LogMessagesWithState
.Where(m => m.LogLevel == LogLevel.Debug &&
(m.Message.Contains("waiting for response") || m.Message.Contains("response received")) &&
m.Message.Contains("tools/call"))
.ToList();

Assert.NotEmpty(relevantLogs);

// Check that at least one log entry has the tool name in its structured state
bool foundToolName = relevantLogs.Any(log =>
{
if (log.State is IReadOnlyList<KeyValuePair<string, object?>> stateList)
{
return stateList.Any(kvp =>
kvp.Key == "ToolName" &&
kvp.Value?.ToString() == "throw_exception");
}
return false;
});

Assert.True(foundToolName, "Tool name 'throw_exception' was not found in structured logging state");
}

[McpServerToolType]
public sealed class EchoTool(ObjectWithId objectFromDI)
{
Expand Down