Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Security.Cryptography;

namespace ModelContextProtocol.AspNetCore;
Expand Down Expand Up @@ -61,18 +62,19 @@ public async Task HandleSseRequestAsync(HttpContext context)
var httpMcpSession = new HttpMcpSession(transport, context.User);
if (!_sessions.TryAdd(sessionId, httpMcpSession))
{
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}

var mcpServerOptions = mcpServerOptionsSnapshot.Value;
if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
{
mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);
await configureSessionOptions(context, mcpServerOptions, cancellationToken);
Debug.Fail("Unreachable given good entropy!");
throw new InvalidOperationException($"Session with ID '{sessionId}' has already been created.");
}

try
{
var mcpServerOptions = mcpServerOptionsSnapshot.Value;
if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
{
mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);
await configureSessionOptions(context, mcpServerOptions, cancellationToken);
}

var transportTask = transport.RunAsync(cancellationToken);

try
Expand Down
30 changes: 22 additions & 8 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
Expand All @@ -10,7 +9,7 @@
namespace ModelContextProtocol.Client;

/// <inheritdoc/>
internal sealed class McpClient : McpEndpoint, IMcpClient
internal sealed partial class McpClient : McpEndpoint, IMcpClient
{
private static Implementation DefaultImplementation { get; } = new()
{
Expand Down Expand Up @@ -133,9 +132,12 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
cancellationToken: initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));
if (_logger.IsEnabled(LogLevel.Information))
{
LogServerCapabilitiesReceived(EndpointName,
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));
}

_serverCapabilities = initializeResponse.Capabilities;
_serverInfo = initializeResponse.ServerInfo;
Expand All @@ -144,7 +146,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
// Validate protocol version
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
{
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
LogServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
throw new McpException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
}

Expand All @@ -155,13 +157,13 @@ await SendMessageAsync(
}
catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested)
{
_logger.ClientInitializationTimeout(EndpointName);
LogClientInitializationTimeout(EndpointName);
throw new McpException("Initialization timed out", oce);
}
}
catch (Exception e)
{
_logger.ClientInitializationError(EndpointName, e);
LogClientInitializationError(EndpointName, e);
await DisposeAsync().ConfigureAwait(false);
throw;
}
Expand All @@ -188,4 +190,16 @@ public override async ValueTask DisposeUnsynchronizedAsync()
}
}
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")]
private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo);

[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")]
private partial void LogClientInitializationError(string endpointName, Exception exception);

[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")]
private partial void LogClientInitializationTimeout(string endpointName);

[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")]
private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received);
}
19 changes: 10 additions & 9 deletions src/ModelContextProtocol/Client/McpClientFactory.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace ModelContextProtocol.Client;

Expand All @@ -14,7 +12,7 @@ namespace ModelContextProtocol.Client;
/// that connect to MCP servers. It handles the creation and connection
/// of appropriate implementations through the supplied transport.
/// </remarks>
public static class McpClientFactory
public static partial class McpClientFactory
{
/// <summary>Creates an <see cref="IMcpClient"/>, connecting it to the specified server.</summary>
/// <param name="clientTransport">The transport instance used to communicate with the server.</param>
Expand All @@ -35,21 +33,24 @@ public static async Task<IMcpClient> CreateAsync(
{
Throw.IfNull(clientTransport);

string endpointName = clientTransport.Name;
var logger = loggerFactory?.CreateLogger(typeof(McpClientFactory)) ?? NullLogger.Instance;
logger.CreatingClient(endpointName);

McpClient client = new(clientTransport, clientOptions, loggerFactory);
try
{
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
logger.ClientCreated(endpointName);
return client;
if (loggerFactory?.CreateLogger(typeof(McpClientFactory)) is ILogger logger)
{
logger.LogClientCreated(client.EndpointName);
}
}
catch
{
await client.DisposeAsync().ConfigureAwait(false);
throw;
}

return client;
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")]
private static partial void LogClientCreated(this ILogger logger, string endpointName);
}
Loading