diff --git a/samples/ProtectedMcpServer/Tools/HttpClientExt.cs b/samples/ProtectedMcpServer/Tools/HttpClientExt.cs deleted file mode 100644 index f7b2b549..00000000 --- a/samples/ProtectedMcpServer/Tools/HttpClientExt.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System.Text.Json; - -namespace ModelContextProtocol; - -internal static class HttpClientExt -{ - public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri) - { - using var response = await client.GetAsync(requestUri); - response.EnsureSuccessStatusCode(); - return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); - } -} \ No newline at end of file diff --git a/samples/ProtectedMcpServer/Tools/WeatherTools.cs b/samples/ProtectedMcpServer/Tools/WeatherTools.cs index 477463c8..94cc0389 100644 --- a/samples/ProtectedMcpServer/Tools/WeatherTools.cs +++ b/samples/ProtectedMcpServer/Tools/WeatherTools.cs @@ -21,9 +21,10 @@ public async Task GetAlerts( [Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state) { var client = _httpClientFactory.CreateClient("WeatherApi"); - using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}"); - var jsonElement = jsonDocument.RootElement; - var alerts = jsonElement.GetProperty("features").EnumerateArray(); + using var jsonDocument = await client.GetFromJsonAsync($"/alerts/active/area/{state}") + ?? throw new McpException("No JSON returned from alerts endpoint"); + + var alerts = jsonDocument.RootElement.GetProperty("features").EnumerateArray(); if (!alerts.Any()) { @@ -50,12 +51,14 @@ public async Task GetForecast( { var client = _httpClientFactory.CreateClient("WeatherApi"); var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); - using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); - var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() - ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); - using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); - var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); + using var locationDocument = await client.GetFromJsonAsync(pointUrl); + var forecastUrl = locationDocument?.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + + using var forecastDocument = await client.GetFromJsonAsync(forecastUrl); + var periods = forecastDocument?.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray() + ?? throw new McpException("No JSON returned from forecast endpoint"); return string.Join("\n---\n", periods.Select(period => $""" {period.GetProperty("name").GetString()} diff --git a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs index e02d4c32..61dc0a0e 100644 --- a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs +++ b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs @@ -43,9 +43,9 @@ public static async Task GetForecast( [Description("Longitude of the location.")] double longitude) { var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); - using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); - var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() - ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + using var locationDocument = await client.ReadJsonDocumentAsync(pointUrl); + var forecastUrl = locationDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 0cdc4e37..2d6b29fd 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -23,6 +23,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder { ArgumentNullException.ThrowIfNull(builder); + builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs deleted file mode 100644 index 1456ce56..00000000 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ /dev/null @@ -1,85 +0,0 @@ -using ModelContextProtocol.AspNetCore.Stateless; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Security.Claims; - -namespace ModelContextProtocol.AspNetCore; - -internal sealed class HttpMcpSession( - string sessionId, - TTransport transport, - UserIdClaim? userId, - TimeProvider timeProvider) : IAsyncDisposable - where TTransport : ITransport -{ - private int _referenceCount; - private int _getRequestStarted; - private CancellationTokenSource _disposeCts = new(); - - public string Id { get; } = sessionId; - public TTransport Transport { get; } = transport; - public UserIdClaim? UserIdClaim { get; } = userId; - - public CancellationToken SessionClosed => _disposeCts.Token; - - public bool IsActive => !SessionClosed.IsCancellationRequested && _referenceCount > 0; - public long LastActivityTicks { get; private set; } = timeProvider.GetTimestamp(); - - private TimeProvider TimeProvider => timeProvider; - - public IMcpServer? Server { get; set; } - public Task? ServerRunTask { get; set; } - - public IDisposable AcquireReference() - { - Interlocked.Increment(ref _referenceCount); - return new UnreferenceDisposable(this); - } - - public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0; - - public async ValueTask DisposeAsync() - { - try - { - await _disposeCts.CancelAsync(); - - if (ServerRunTask is not null) - { - await ServerRunTask; - } - } - catch (OperationCanceledException) - { - } - finally - { - try - { - if (Server is not null) - { - await Server.DisposeAsync(); - } - } - finally - { - await Transport.DisposeAsync(); - _disposeCts.Dispose(); - } - } - } - - public bool HasSameUserId(ClaimsPrincipal user) - => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user); - - private sealed class UnreferenceDisposable(HttpMcpSession session) : IDisposable - { - public void Dispose() - { - if (Interlocked.Decrement(ref session._referenceCount) == 0) - { - session.LastActivityTicks = session.TimeProvider.GetTimestamp(); - } - } - } -} diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 2a34a17a..94de9cb9 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -66,9 +66,9 @@ public class HttpServerTransportOptions /// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached /// their until the idle session count is below this limit. Clients that keep their session open by /// keeping a GET request open will not count towards this limit. - /// Defaults to 100,000 sessions. + /// Defaults to 10,000 sessions. /// - public int MaxIdleSessionCount { get; set; } = 100_000; + public int MaxIdleSessionCount { get; set; } = 10_000; /// /// Used for testing the . diff --git a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs index c4a5f11e..a4ae569b 100644 --- a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs +++ b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs @@ -1,18 +1,16 @@ -using System.Runtime.InteropServices; -using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Server; namespace ModelContextProtocol.AspNetCore; internal sealed partial class IdleTrackingBackgroundService( - StreamableHttpHandler handler, + StatefulSessionManager sessions, IOptions options, IHostApplicationLifetime appLifetime, ILogger logger) : BackgroundService { - // The compiler will complain about the parameter being unused otherwise despite the source generator. + // Workaround for https://github.com/dotnet/runtime/issues/91121. This is fixed in .NET 9 and later. private readonly ILogger _logger = logger; protected override async Task ExecuteAsync(CancellationToken stoppingToken) @@ -30,65 +28,9 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) var timeProvider = options.Value.TimeProvider; using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider); - var idleTimeoutTicks = options.Value.IdleTimeout.Ticks; - var maxIdleSessionCount = options.Value.MaxIdleSessionCount; - - // Create two lists that will be reused between runs. - // This assumes that the number of idle sessions is not breached frequently. - // If the idle sessions often breach the maximum, a priority queue could be considered. - var idleSessionsTimestamps = new List(); - var idleSessionSessionIds = new List(); - while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken)) { - var idleActivityCutoff = idleTimeoutTicks switch - { - < 0 => long.MinValue, - var ticks => timeProvider.GetTimestamp() - ticks, - }; - - foreach (var (_, session) in handler.Sessions) - { - if (session.IsActive || session.SessionClosed.IsCancellationRequested) - { - // There's a request currently active or the session is already being closed. - continue; - } - - if (session.LastActivityTicks < idleActivityCutoff) - { - RemoveAndCloseSession(session.Id); - continue; - } - - // Add the timestamp and the session - idleSessionsTimestamps.Add(session.LastActivityTicks); - idleSessionSessionIds.Add(session.Id); - - // Emit critical log at most once every 5 seconds the idle count it exceeded, - // since the IdleTimeout will no longer be respected. - if (idleSessionsTimestamps.Count == maxIdleSessionCount + 1) - { - LogMaxSessionIdleCountExceeded(maxIdleSessionCount); - } - } - - if (idleSessionsTimestamps.Count > maxIdleSessionCount) - { - var timestamps = CollectionsMarshal.AsSpan(idleSessionsTimestamps); - - // Sort only if the maximum is breached and sort solely by the timestamp. Sort both collections. - timestamps.Sort(CollectionsMarshal.AsSpan(idleSessionSessionIds)); - - var sessionsToPrune = CollectionsMarshal.AsSpan(idleSessionSessionIds)[..^maxIdleSessionCount]; - foreach (var id in sessionsToPrune) - { - RemoveAndCloseSession(id); - } - } - - idleSessionsTimestamps.Clear(); - idleSessionSessionIds.Clear(); + await sessions.PruneIdleSessionsAsync(stoppingToken); } } catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) @@ -98,17 +40,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { try { - List disposeSessionTasks = []; - - foreach (var (sessionKey, _) in handler.Sessions) - { - if (handler.Sessions.TryRemove(sessionKey, out var session)) - { - disposeSessionTasks.Add(DisposeSessionAsync(session)); - } - } - - await Task.WhenAll(disposeSessionTasks); + await sessions.DisposeAllSessionsAsync(); } finally { @@ -123,39 +55,6 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) } } - private void RemoveAndCloseSession(string sessionId) - { - if (!handler.Sessions.TryRemove(sessionId, out var session)) - { - return; - } - - LogSessionIdle(session.Id); - // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. - _ = DisposeSessionAsync(session); - } - - private async Task DisposeSessionAsync(HttpMcpSession session) - { - try - { - await session.DisposeAsync(); - } - catch (Exception ex) - { - LogSessionDisposeError(session.Id, ex); - } - } - - [LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")] - private partial void LogSessionIdle(string sessionId); - - [LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {sessionId}.")] - private partial void LogSessionDisposeError(string sessionId, Exception ex); - - [LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded maximum of {maxIdleSessionCount} idle sessions. Now closing sessions active more recently than configured IdleTimeout.")] - private partial void LogMaxSessionIdleCountExceeded(int maxIdleSessionCount); - [LoggerMessage(Level = LogLevel.Critical, Message = "The IdleTrackingBackgroundService has stopped unexpectedly.")] private partial void IdleTrackingBackgroundServiceStoppedUnexpectedly(); } \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index c5ac5a94..6ed72fb6 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -16,7 +16,7 @@ internal sealed class SseHandler( IHostApplicationLifetime hostApplicationLifetime, ILoggerFactory loggerFactory) { - private readonly ConcurrentDictionary> _sessions = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); public async Task HandleSseRequestAsync(HttpContext context) { @@ -34,9 +34,9 @@ public async Task HandleSseRequestAsync(HttpContext context) await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId); var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User); - await using var httpMcpSession = new HttpMcpSession(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider); + var sseSession = new SseSession(transport, userIdClaim); - if (!_sessions.TryAdd(sessionId, httpMcpSession)) + if (!_sessions.TryAdd(sessionId, sseSession)) { throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } @@ -55,12 +55,10 @@ public async Task HandleSseRequestAsync(HttpContext context) try { await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); - httpMcpSession.Server = mcpServer; context.Features.Set(mcpServer); var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync; - httpMcpSession.ServerRunTask = runSessionAsync(context, mcpServer, cancellationToken); - await httpMcpSession.ServerRunTask; + await runSessionAsync(context, mcpServer, cancellationToken); } finally { @@ -87,13 +85,13 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession)) + if (!_sessions.TryGetValue(sessionId.ToString(), out var sseSession)) { await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); return; } - if (!httpMcpSession.HasSameUserId(context.User)) + if (sseSession.UserId != StreamableHttpHandler.GetUserIdClaim(context.User)) { await Results.Forbid().ExecuteAsync(context); return; @@ -106,8 +104,10 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); + await sseSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); context.Response.StatusCode = StatusCodes.Status202Accepted; await context.Response.WriteAsync("Accepted"); } + + private record SseSession(SseResponseStreamTransport Transport, UserIdClaim? UserId); } diff --git a/src/ModelContextProtocol.AspNetCore/StatefulSessionManager.cs b/src/ModelContextProtocol.AspNetCore/StatefulSessionManager.cs new file mode 100644 index 00000000..960488af --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatefulSessionManager.cs @@ -0,0 +1,243 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed partial class StatefulSessionManager( + IOptions httpServerTransportOptions, + ILogger logger) +{ + // Workaround for https://github.com/dotnet/runtime/issues/91121. This is fixed in .NET 9 and later. + private readonly ILogger _logger = logger; + + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + + private readonly TimeProvider _timeProvider = httpServerTransportOptions.Value.TimeProvider; + private readonly TimeSpan _idleTimeout = httpServerTransportOptions.Value.IdleTimeout; + private readonly long _idleTimeoutTicks = httpServerTransportOptions.Value.IdleTimeout.Ticks; + private readonly int _maxIdleSessionCount = httpServerTransportOptions.Value.MaxIdleSessionCount; + + private readonly object _idlePruningLock = new(); + private readonly List _idleTimestamps = []; + private readonly List _idleSessionIds = []; + private int _nextIndexToPrune; + + private long _currentIdleSessionCount; + + public TimeProvider TimeProvider => _timeProvider; + + public void IncrementIdleSessionCount() => Interlocked.Increment(ref _currentIdleSessionCount); + public void DecrementIdleSessionCount() => Interlocked.Decrement(ref _currentIdleSessionCount); + + public bool TryGetValue(string key, [NotNullWhen(true)] out StreamableHttpSession? value) => _sessions.TryGetValue(key, out value); + public bool TryRemove(string key, [NotNullWhen(true)] out StreamableHttpSession? value) => _sessions.TryRemove(key, out value); + + public async ValueTask StartNewSessionAsync(StreamableHttpSession newSession, CancellationToken cancellationToken) + { + while (!TryAddSessionImmediately(newSession)) + { + StreamableHttpSession? sessionToPrune = null; + + lock (_idlePruningLock) + { + EnsureIdleSessionsSortedUnsynchronized(); + + while (_nextIndexToPrune < _idleSessionIds.Count) + { + var pruneId = _idleSessionIds[_nextIndexToPrune++]; + if (_sessions.TryRemove(pruneId, out sessionToPrune)) + { + LogIdleSessionLimit(pruneId, _maxIdleSessionCount); + break; + } + } + + if (sessionToPrune is null) + { + // If we couldn't find any active idle sessions to dispose, start another full prune to repopulate _idleSessionIds. + PruneIdleSessionsUnsynchronized(); + + if (_idleSessionIds.Count > 0) + { + continue; + } + else + { + // This indicates all idle sessions are in the process of being disposed which should not happen during normal operation. + // Since there are no idle sessions to prune right now, log a critical error and create the new session anyway. + LogTooManyIdleSessionsClosingConcurrently(newSession.Id, _maxIdleSessionCount, Volatile.Read(ref _currentIdleSessionCount)); + AddSession(newSession); + return; + } + } + } + + try + { + // Since we're at or above the maximum idle session count, we're intentionally waiting for the idle session to be disposed + // before adding a new session to the dictionary to ensure sessions not created faster than they're removed. + await DisposeSessionAsync(sessionToPrune); + + // Take one last chance to check if the initialize request was aborted before we incur the cost of managing a new session. + cancellationToken.ThrowIfCancellationRequested(); + AddSession(newSession); + return; + } + catch + { + await newSession.DisposeAsync(); + throw; + } + } + } + + /// + /// Performs a single pass of idle session pruning, removing sessions that exceed the idle timeout + /// or when the maximum idle session count is exceeded. + /// + public async Task PruneIdleSessionsAsync(CancellationToken cancellationToken) + { + lock (_idlePruningLock) + { + PruneIdleSessionsUnsynchronized(); + } + } + + private void PruneIdleSessionsUnsynchronized() + { + var idleActivityCutoff = _idleTimeoutTicks switch + { + < 0 => long.MinValue, + var ticks => _timeProvider.GetTimestamp() - ticks, + }; + + // We clear the lists at the start of pruning rather than the end so we can use them between runs + // to find the most idle sessions to remove one-at-a-time if necessary to make room for new sessions. + _idleTimestamps.Clear(); + _idleSessionIds.Clear(); + _nextIndexToPrune = -1; + + foreach (var (_, session) in _sessions) + { + if (session.IsActive || session.SessionClosed.IsCancellationRequested) + { + // There's a request currently active or the session is already being closed. + continue; + } + + if (session.LastActivityTicks < idleActivityCutoff) + { + LogIdleSessionTimeout(session.Id, _idleTimeout); + RemoveAndCloseSession(session.Id); + continue; + } + + // Add the timestamp and the session + _idleTimestamps.Add(session.LastActivityTicks); + _idleSessionIds.Add(session.Id); + } + + if (_idleTimestamps.Count > _maxIdleSessionCount) + { + // Sort only if the maximum is breached and sort solely by the timestamp. + EnsureIdleSessionsSortedUnsynchronized(); + + var sessionsToPrune = CollectionsMarshal.AsSpan(_idleSessionIds)[..^_maxIdleSessionCount]; + foreach (var id in sessionsToPrune) + { + LogIdleSessionLimit(id, _maxIdleSessionCount); + RemoveAndCloseSession(id); + } + _nextIndexToPrune = _maxIdleSessionCount; + } + } + + private void EnsureIdleSessionsSortedUnsynchronized() + { + if (_nextIndexToPrune > -1) + { + // Already sorted. + return; + } + + var timestamps = CollectionsMarshal.AsSpan(_idleTimestamps); + timestamps.Sort(CollectionsMarshal.AsSpan(_idleSessionIds)); + _nextIndexToPrune = 0; + } + + /// + /// Disposes all sessions in the manager, typically called during graceful shutdown. + /// + public async Task DisposeAllSessionsAsync() + { + List disposeSessionTasks = []; + + foreach (var (sessionKey, _) in _sessions) + { + if (_sessions.TryRemove(sessionKey, out var session)) + { + disposeSessionTasks.Add(DisposeSessionAsync(session)); + } + } + + await Task.WhenAll(disposeSessionTasks); + } + + private bool TryAddSessionImmediately(StreamableHttpSession session) + { + if (Volatile.Read(ref _currentIdleSessionCount) < _maxIdleSessionCount) + { + AddSession(session); + return true; + } + + return false; + } + + private void AddSession(StreamableHttpSession session) + { + if (!_sessions.TryAdd(session.Id, session)) + { + throw new UnreachableException($"Unreachable given good entropy! Session with ID '{session.Id}' has already been created."); + } + } + + private void RemoveAndCloseSession(string sessionId) + { + if (!_sessions.TryRemove(sessionId, out var session)) + { + return; + } + + // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. + _ = DisposeSessionAsync(session); + } + + private async Task DisposeSessionAsync(StreamableHttpSession session) + { + try + { + await session.DisposeAsync(); + } + catch (Exception ex) + { + LogSessionDisposeError(session.Id, ex); + } + } + + [LoggerMessage(Level = LogLevel.Information, Message = "IdleTimeout of {IdleTimeout} exceeded. Closing idle session {SessionId}.")] + private partial void LogIdleSessionTimeout(string sessionId, TimeSpan idleTimeout); + + [LoggerMessage(Level = LogLevel.Information, Message = "MaxIdleSessionCount of {MaxIdleSessionCount} exceeded. Closing idle session {SessionId} despite it being active more recently than the configured IdleTimeout to make room for new sessions.")] + private partial void LogIdleSessionLimit(string sessionId, int maxIdleSessionCount); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {SessionId}.")] + private partial void LogSessionDisposeError(string sessionId, Exception ex); + + [LoggerMessage(Level = LogLevel.Critical, Message = "MaxIdleSessionCount of {MaxIdleSessionCount} exceeded, and {CurrentIdleSessionCount} sessions are currently in the process of closing. Creating new session {SessionId} anyway.")] + private partial void LogTooManyIdleSessionsClosingConcurrently(string sessionId, int maxIdleSessionCount, long currentIdleSessionCount); +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 6dac1c3e..bfbd805d 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -8,8 +8,6 @@ using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.Collections.Concurrent; -using System.Diagnostics; using System.IO.Pipelines; using System.Security.Claims; using System.Security.Cryptography; @@ -22,6 +20,7 @@ internal sealed class StreamableHttpHandler( IOptions mcpServerOptionsSnapshot, IOptionsFactory mcpServerOptionsFactory, IOptions httpServerTransportOptions, + StatefulSessionManager sessionManager, IDataProtectionProvider dataProtection, ILoggerFactory loggerFactory, IServiceProvider applicationServices) @@ -29,8 +28,6 @@ internal sealed class StreamableHttpHandler( private const string McpSessionIdHeaderName = "Mcp-Session-Id"; private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); - public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); - public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; private IDataProtector Protector { get; } = dataProtection.CreateProtector("Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId"); @@ -56,28 +53,15 @@ await WriteJsonRpcErrorAsync(context, return; } - try - { - using var _ = session.AcquireReference(); + await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); - InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); - if (!wroteResponse) - { - // We wound up writing nothing, so there should be no Content-Type response header. - context.Response.Headers.ContentType = (string?)null; - context.Response.StatusCode = StatusCodes.Status202Accepted; - } - } - finally + InitializeSseResponse(context); + var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); + if (!wroteResponse) { - // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the Mcp-Session-Id. - // Non-stateless sessions are 1:1 with the Mcp-Session-Id and outlive the POST request. - // Non-stateless sessions get disposed by a DELETE request or the IdleTrackingBackgroundService. - if (HttpServerTransportOptions.Stateless) - { - await session.DisposeAsync(); - } + // We wound up writing nothing, so there should be no Content-Type response header. + context.Response.Headers.ContentType = (string?)null; + context.Response.StatusCode = StatusCodes.Status202Accepted; } } @@ -106,7 +90,7 @@ await WriteJsonRpcErrorAsync(context, return; } - using var _ = session.AcquireReference(); + await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); InitializeSseResponse(context); // We should flush headers to indicate a 200 success quickly, because the initialization response @@ -119,17 +103,22 @@ await WriteJsonRpcErrorAsync(context, public async Task HandleDeleteRequestAsync(HttpContext context) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); - if (Sessions.TryRemove(sessionId, out var session)) + if (sessionManager.TryRemove(sessionId, out var session)) { await session.DisposeAsync(); } } - private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId) + private async ValueTask GetSessionAsync(HttpContext context, string sessionId) { - HttpMcpSession? session; + StreamableHttpSession? session; - if (HttpServerTransportOptions.Stateless) + if (string.IsNullOrEmpty(sessionId)) + { + await WriteJsonRpcErrorAsync(context, "Bad Request: Mcp-Session-Id header is required", StatusCodes.Status400BadRequest); + return null; + } + else if (HttpServerTransportOptions.Stateless) { var sessionJson = Protector.Unprotect(sessionId); var statelessSessionId = JsonSerializer.Deserialize(sessionJson, StatelessSessionIdJsonContext.Default.StatelessSessionId); @@ -140,7 +129,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context) }; session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); } - else if (!Sessions.TryGetValue(sessionId, out session)) + else if (!sessionManager.TryGetValue(sessionId, out session)) { // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this @@ -163,7 +152,7 @@ await WriteJsonRpcErrorAsync(context, return session; } - private async ValueTask?> GetOrCreateSessionAsync(HttpContext context) + private async ValueTask GetOrCreateSessionAsync(HttpContext context) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); @@ -177,7 +166,7 @@ await WriteJsonRpcErrorAsync(context, } } - private async ValueTask> StartNewSessionAsync(HttpContext context) + private async ValueTask StartNewSessionAsync(HttpContext context) { string sessionId; StreamableHttpServerTransport transport; @@ -204,21 +193,10 @@ private async ValueTask> StartNewS ScheduleStatelessSessionIdWrite(context, transport); } - var session = await CreateSessionAsync(context, transport, sessionId); - - // The HttpMcpSession is not stored between requests in stateless mode. Instead, the session is recreated from the MCP-Session-Id. - if (!HttpServerTransportOptions.Stateless) - { - if (!Sessions.TryAdd(sessionId, session)) - { - throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); - } - } - - return session; + return await CreateSessionAsync(context, transport, sessionId); } - private async ValueTask> CreateSessionAsync( + private async ValueTask CreateSessionAsync( HttpContext context, StreamableHttpServerTransport transport, string sessionId, @@ -248,10 +226,7 @@ private async ValueTask> CreateSes context.Features.Set(server); var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); - var session = new HttpMcpSession(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider) - { - Server = server, - }; + var session = new StreamableHttpSession(sessionId, transport, server, userIdClaim, sessionManager); var runSessionAsync = HttpServerTransportOptions.RunSessionHandler ?? RunSessionAsync; session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed); diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs new file mode 100644 index 00000000..ffeafada --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -0,0 +1,164 @@ +using ModelContextProtocol.Server; +using System.Diagnostics; +using System.Security.Claims; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed class StreamableHttpSession( + string sessionId, + StreamableHttpServerTransport transport, + IMcpServer server, + UserIdClaim? userId, + StatefulSessionManager sessionManager) : IAsyncDisposable +{ + private int _referenceCount; + private SessionState _state; + private readonly object _stateLock = new(); + + private int _getRequestStarted; + private readonly CancellationTokenSource _disposeCts = new(); + + public string Id => sessionId; + public StreamableHttpServerTransport Transport => transport; + public IMcpServer Server => server; + private StatefulSessionManager SessionManager => sessionManager; + + public CancellationToken SessionClosed => _disposeCts.Token; + public bool IsActive => !SessionClosed.IsCancellationRequested && _referenceCount > 0; + public long LastActivityTicks { get; private set; } = sessionManager.TimeProvider.GetTimestamp(); + + public Task ServerRunTask { get; set; } = Task.CompletedTask; + + public async ValueTask AcquireReferenceAsync(CancellationToken cancellationToken) + { + // The StreamableHttpSession is not stored between requests in stateless mode. Instead, the session is recreated from the MCP-Session-Id. + // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the Mcp-Session-Id. + // Non-stateless sessions are 1:1 with the Mcp-Session-Id and outlive the POST request. + // Non-stateless sessions get disposed by a DELETE request or the IdleTrackingBackgroundService. + if (transport.Stateless) + { + return this; + } + + SessionState startingState; + + lock (_stateLock) + { + startingState = _state; + _referenceCount++; + + switch (startingState) + { + case SessionState.Uninitialized: + Debug.Assert(_referenceCount == 1, "The _referenceCount should start at 1 when the StreamableHttpSession is uninitialized."); + _state = SessionState.Started; + break; + case SessionState.Started: + if (_referenceCount == 1) + { + sessionManager.DecrementIdleSessionCount(); + } + break; + case SessionState.Disposed: + throw new ObjectDisposedException(nameof(StreamableHttpSession)); + } + } + + if (startingState == SessionState.Uninitialized) + { + await sessionManager.StartNewSessionAsync(this, cancellationToken); + } + + return new UnreferenceDisposable(this); + } + + public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0; + public bool HasSameUserId(ClaimsPrincipal user) => userId == StreamableHttpHandler.GetUserIdClaim(user); + + public async ValueTask DisposeAsync() + { + var wasIdle = false; + + lock (_stateLock) + { + switch (_state) + { + case SessionState.Uninitialized: + break; + case SessionState.Started: + if (_referenceCount == 0) + { + wasIdle = true; + } + break; + case SessionState.Disposed: + return; + } + + _state = SessionState.Disposed; + } + + try + { + await _disposeCts.CancelAsync(); + + try + { + await ServerRunTask; + } + finally + { + await DisposeServerThenTransportAsync(); + } + } + catch (OperationCanceledException) + { + } + finally + { + if (wasIdle) + { + sessionManager.DecrementIdleSessionCount(); + } + _disposeCts.Dispose(); + } + } + + private async ValueTask DisposeServerThenTransportAsync() + { + try + { + await server.DisposeAsync(); + } + finally + { + await transport.DisposeAsync(); + } + } + + private sealed class UnreferenceDisposable(StreamableHttpSession session) : IAsyncDisposable + { + public ValueTask DisposeAsync() + { + lock (session._stateLock) + { + Debug.Assert(session._state != SessionState.Uninitialized, "The session should have been initialized."); + if (session._state != SessionState.Disposed && --session._referenceCount == 0) + { + var sessionManager = session.SessionManager; + session.LastActivityTicks = sessionManager.TimeProvider.GetTimestamp(); + sessionManager.IncrementIdleSessionCount(); + } + } + + return default; + } + } + + private enum SessionState + { + Uninitialized, + Started, + Disposed + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs b/src/ModelContextProtocol.AspNetCore/UserIdClaim.cs similarity index 58% rename from src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs rename to src/ModelContextProtocol.AspNetCore/UserIdClaim.cs index f18c1c5f..5b5951d3 100644 --- a/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs +++ b/src/ModelContextProtocol.AspNetCore/UserIdClaim.cs @@ -1,3 +1,3 @@ -namespace ModelContextProtocol.AspNetCore.Stateless; +namespace ModelContextProtocol.AspNetCore; internal sealed record UserIdClaim(string Type, string Value, string Issuer); diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index aba7bbcf..479a7627 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -193,6 +193,8 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation return; } + LogTransportReceivedMessageSensitive(Name, data); + try { var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 190bec0b..c4014ed7 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -211,6 +211,8 @@ private async Task ReceiveUnsolicitedMessagesAsync() private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) { + LogTransportReceivedMessageSensitive(Name, data); + try { var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); diff --git a/tests/Common/Utils/MockLoggerProvider.cs b/tests/Common/Utils/MockLoggerProvider.cs index f5264edc..14a0f401 100644 --- a/tests/Common/Utils/MockLoggerProvider.cs +++ b/tests/Common/Utils/MockLoggerProvider.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Tests.Utils; public class MockLoggerProvider() : ILoggerProvider { - public ConcurrentQueue<(string Category, LogLevel LogLevel, string Message, Exception? Exception)> LogMessages { get; } = []; + public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> LogMessages { get; } = []; public ILogger CreateLogger(string categoryName) { @@ -21,7 +21,7 @@ private class MockLogger(MockLoggerProvider mockProvider, string category) : ILo public void Log( LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { - mockProvider.LogMessages.Enqueue((category, logLevel, formatter(state, exception), exception)); + mockProvider.LogMessages.Enqueue((category, logLevel, eventId, formatter(state, exception), exception)); } public bool IsEnabled(LogLevel logLevel) => true; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 0b3ae4c2..bb184034 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -505,7 +505,6 @@ public async Task IdleSessionsPastMaxIdleSessionCount_ArePruned_LongestIdleFirst Assert.NotEqual(secondSessionId, thirdSessionId); // Pruning of the second session results in a 404 since we used the first session more recently. - fakeTimeProvider.Advance(TimeSpan.FromSeconds(10)); SetSessionId(secondSessionId); using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -517,8 +516,9 @@ public async Task IdleSessionsPastMaxIdleSessionCount_ArePruned_LongestIdleFirst SetSessionId(thirdSessionId); await CallEchoAndValidateAsync(); - var logMessage = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Critical); - Assert.StartsWith("Exceeded maximum of 2 idle sessions.", logMessage.Message); + var idleLimitLogMessage = Assert.Single(mockLoggerProvider.LogMessages, m => m.EventId.Name == "LogIdleSessionLimit"); + Assert.Equal(LogLevel.Information, idleLimitLogMessage.LogLevel); + Assert.StartsWith("MaxIdleSessionCount of 2 exceeded. Closing idle session", idleLimitLogMessage.Message); } private static StringContent JsonContent(string json) => new StringContent(json, Encoding.UTF8, "application/json");