diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 1456ce56..ee066707 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -1,7 +1,9 @@ using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Diagnostics; using System.Security.Claims; +using System.Threading; namespace ModelContextProtocol.AspNetCore; @@ -9,12 +11,17 @@ internal sealed class HttpMcpSession( string sessionId, TTransport transport, UserIdClaim? userId, - TimeProvider timeProvider) : IAsyncDisposable + TimeProvider timeProvider, + SemaphoreSlim? idleSessionSemaphore = null) : IAsyncDisposable where TTransport : ITransport { private int _referenceCount; private int _getRequestStarted; - private CancellationTokenSource _disposeCts = new(); + private bool _isDisposed; + + private readonly SemaphoreSlim? _idleSessionSemaphore = idleSessionSemaphore; + private readonly CancellationTokenSource _disposeCts = new(); + private readonly object _referenceCountLock = new(); public string Id { get; } = sessionId; public TTransport Transport { get; } = transport; @@ -30,9 +37,23 @@ internal sealed class HttpMcpSession( public IMcpServer? Server { get; set; } public Task? ServerRunTask { get; set; } - public IDisposable AcquireReference() + public IAsyncDisposable AcquireReference() { - Interlocked.Increment(ref _referenceCount); + // We don't do idle tracking for stateless sessions, so we don't need to acquire a reference. + if (_idleSessionSemaphore is null) + { + return new NoopDisposable(); + } + + lock (_referenceCountLock) + { + if (!_isDisposed && ++_referenceCount == 1) + { + // Non-idle sessions should not prevent session creation. + _idleSessionSemaphore.Release(); + } + } + return new UnreferenceDisposable(this); } @@ -40,6 +61,19 @@ public IDisposable AcquireReference() public async ValueTask DisposeAsync() { + bool shouldReleaseIdleSessionSemaphore; + + lock (_referenceCountLock) + { + if (_isDisposed) + { + return; + } + + _isDisposed = true; + shouldReleaseIdleSessionSemaphore = _referenceCount == 0; + } + try { await _disposeCts.CancelAsync(); @@ -65,21 +99,45 @@ public async ValueTask DisposeAsync() { await Transport.DisposeAsync(); _disposeCts.Dispose(); + + // If the session was disposed while it was inactive, we need to release the semaphore + // to allow new sessions to be created. + if (_idleSessionSemaphore is not null && shouldReleaseIdleSessionSemaphore) + { + _idleSessionSemaphore.Release(); + } } } } - public bool HasSameUserId(ClaimsPrincipal user) - => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user); + public bool HasSameUserId(ClaimsPrincipal user) => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user); - private sealed class UnreferenceDisposable(HttpMcpSession session) : IDisposable + private sealed class UnreferenceDisposable(HttpMcpSession session) : IAsyncDisposable { - public void Dispose() + public async ValueTask DisposeAsync() { - if (Interlocked.Decrement(ref session._referenceCount) == 0) + Debug.Assert(session._idleSessionSemaphore is not null, "Only StreamableHttpHandler should call AcquireReference."); + + bool shouldMarkSessionIdle; + + lock (session._referenceCountLock) + { + shouldMarkSessionIdle = !session._isDisposed && --session._referenceCount == 0; + } + + if (shouldMarkSessionIdle) { session.LastActivityTicks = session.TimeProvider.GetTimestamp(); + + // Acquire semaphore when session becomes inactive (reference count goes to 0) to slow + // down session creation until idle sessions are disposed by the background service. + await session._idleSessionSemaphore.WaitAsync(); } } } + + private sealed class NoopDisposable : IAsyncDisposable + { + public ValueTask DisposeAsync() => ValueTask.CompletedTask; + } } diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 2a34a17a..94de9cb9 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -66,9 +66,9 @@ public class HttpServerTransportOptions /// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached /// their until the idle session count is below this limit. Clients that keep their session open by /// keeping a GET request open will not count towards this limit. - /// Defaults to 100,000 sessions. + /// Defaults to 10,000 sessions. /// - public int MaxIdleSessionCount { get; set; } = 100_000; + public int MaxIdleSessionCount { get; set; } = 10_000; /// /// Used for testing the . diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 6dac1c3e..2231b7ce 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -31,6 +31,10 @@ internal sealed class StreamableHttpHandler( public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); + // Semaphore to control session creation backpressure when there are too many idle sessions + // Initial and max count is 10% more than MaxIdleSessionCount (or 100 more if that's higher) + private readonly SemaphoreSlim _idleSessionSemaphore = CreateIdleSessionSemaphore(httpServerTransportOptions.Value); + public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; private IDataProtector Protector { get; } = dataProtection.CreateProtector("Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId"); @@ -58,7 +62,7 @@ await WriteJsonRpcErrorAsync(context, try { - using var _ = session.AcquireReference(); + await using var _ = session.AcquireReference(); InitializeSseResponse(context); var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); @@ -106,7 +110,7 @@ await WriteJsonRpcErrorAsync(context, return; } - using var _ = session.AcquireReference(); + await using var _ = session.AcquireReference(); InitializeSseResponse(context); // We should flush headers to indicate a 200 success quickly, because the initialization response @@ -184,6 +188,11 @@ private async ValueTask> StartNewS if (!HttpServerTransportOptions.Stateless) { + // Acquire semaphore before creating stateful sessions to create backpressure. + // This semaphore represents "slots" for idle sessions, and we may need to wait on the + // IdleTrackingBackgroundService to dispose idle sessions before continuing. + await _idleSessionSemaphore.WaitAsync(context.RequestAborted); + sessionId = MakeNewSessionId(); transport = new() { @@ -248,7 +257,8 @@ private async ValueTask> CreateSes context.Features.Set(server); var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); - var session = new HttpMcpSession(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider) + var semaphore = HttpServerTransportOptions.Stateless ? null : _idleSessionSemaphore; + var session = new HttpMcpSession(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider, semaphore) { Server = server, }; @@ -337,6 +347,13 @@ private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptH private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("text/event-stream"); + private static SemaphoreSlim CreateIdleSessionSemaphore(HttpServerTransportOptions options) + { + var maxIdleSessionCount = options.MaxIdleSessionCount; + var semaphoreCount = Math.Max(maxIdleSessionCount + 100, (int)(maxIdleSessionCount * 1.1)); + return new SemaphoreSlim(semaphoreCount, semaphoreCount); + } + private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe { public PipeReader Input => context.Request.BodyReader; diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index 06b2894b..064dd39a 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -736,7 +736,7 @@ private static TimeSpan GetElapsed(long startingTimestamp) => [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received message for unknown request ID '{RequestId}'.")] private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")]