Skip to content

Prune idle sessions before starting new ones #701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions samples/ProtectedMcpServer/Tools/HttpClientExt.cs

This file was deleted.

19 changes: 11 additions & 8 deletions samples/ProtectedMcpServer/Tools/WeatherTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ public async Task<string> GetAlerts(
[Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state)
{
var client = _httpClientFactory.CreateClient("WeatherApi");
using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}");
var jsonElement = jsonDocument.RootElement;
var alerts = jsonElement.GetProperty("features").EnumerateArray();
using var jsonDocument = await client.GetFromJsonAsync<JsonDocument>($"/alerts/active/area/{state}")
?? throw new McpException("No JSON returned from alerts endpoint");

var alerts = jsonDocument.RootElement.GetProperty("features").EnumerateArray();

if (!alerts.Any())
{
Expand All @@ -50,12 +51,14 @@ public async Task<string> GetForecast(
{
var client = _httpClientFactory.CreateClient("WeatherApi");
var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}");
using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl);
var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");

using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl);
var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray();
using var locationDocument = await client.GetFromJsonAsync<JsonDocument>(pointUrl);
var forecastUrl = locationDocument?.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");

using var forecastDocument = await client.GetFromJsonAsync<JsonDocument>(forecastUrl);
var periods = forecastDocument?.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray()
?? throw new McpException("No JSON returned from forecast endpoint");

return string.Join("\n---\n", periods.Select(period => $"""
{period.GetProperty("name").GetString()}
Expand Down
6 changes: 3 additions & 3 deletions samples/QuickstartWeatherServer/Tools/WeatherTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ public static async Task<string> GetForecast(
[Description("Longitude of the location.")] double longitude)
{
var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}");
using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl);
var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");
using var locationDocument = await client.ReadJsonDocumentAsync(pointUrl);
var forecastUrl = locationDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString()
?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}");

using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl);
var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder
{
ArgumentNullException.ThrowIfNull(builder);

builder.Services.TryAddSingleton<StatefulSessionManager>();
builder.Services.TryAddSingleton<StreamableHttpHandler>();
builder.Services.TryAddSingleton<SseHandler>();
builder.Services.AddHostedService<IdleTrackingBackgroundService>();
Expand Down
85 changes: 0 additions & 85 deletions src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public class HttpServerTransportOptions
/// keeping a GET request open will not count towards this limit.
/// Defaults to 100,000 sessions.
/// </remarks>
public int MaxIdleSessionCount { get; set; } = 100_000;
public int MaxIdleSessionCount { get; set; } = 10_000;

/// <summary>
/// Used for testing the <see cref="IdleTimeout"/>.
Expand Down
113 changes: 6 additions & 107 deletions src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
using System.Runtime.InteropServices;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Server;

namespace ModelContextProtocol.AspNetCore;

internal sealed partial class IdleTrackingBackgroundService(
StreamableHttpHandler handler,
StatefulSessionManager sessions,
IOptions<HttpServerTransportOptions> options,
IHostApplicationLifetime appLifetime,
ILogger<IdleTrackingBackgroundService> logger) : BackgroundService
{
// The compiler will complain about the parameter being unused otherwise despite the source generator.
// Workaround for https://github.com/dotnet/runtime/issues/91121. This is fixed in .NET 9 and later.
private readonly ILogger _logger = logger;

protected override async Task ExecuteAsync(CancellationToken stoppingToken)
Expand All @@ -30,65 +28,9 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
var timeProvider = options.Value.TimeProvider;
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);

var idleTimeoutTicks = options.Value.IdleTimeout.Ticks;
var maxIdleSessionCount = options.Value.MaxIdleSessionCount;

// Create two lists that will be reused between runs.
// This assumes that the number of idle sessions is not breached frequently.
// If the idle sessions often breach the maximum, a priority queue could be considered.
var idleSessionsTimestamps = new List<long>();
var idleSessionSessionIds = new List<string>();

while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken))
{
var idleActivityCutoff = idleTimeoutTicks switch
{
< 0 => long.MinValue,
var ticks => timeProvider.GetTimestamp() - ticks,
};

foreach (var (_, session) in handler.Sessions)
{
if (session.IsActive || session.SessionClosed.IsCancellationRequested)
{
// There's a request currently active or the session is already being closed.
continue;
}

if (session.LastActivityTicks < idleActivityCutoff)
{
RemoveAndCloseSession(session.Id);
continue;
}

// Add the timestamp and the session
idleSessionsTimestamps.Add(session.LastActivityTicks);
idleSessionSessionIds.Add(session.Id);

// Emit critical log at most once every 5 seconds the idle count it exceeded,
// since the IdleTimeout will no longer be respected.
if (idleSessionsTimestamps.Count == maxIdleSessionCount + 1)
{
LogMaxSessionIdleCountExceeded(maxIdleSessionCount);
}
}

if (idleSessionsTimestamps.Count > maxIdleSessionCount)
{
var timestamps = CollectionsMarshal.AsSpan(idleSessionsTimestamps);

// Sort only if the maximum is breached and sort solely by the timestamp. Sort both collections.
timestamps.Sort(CollectionsMarshal.AsSpan(idleSessionSessionIds));

var sessionsToPrune = CollectionsMarshal.AsSpan(idleSessionSessionIds)[..^maxIdleSessionCount];
foreach (var id in sessionsToPrune)
{
RemoveAndCloseSession(id);
}
}

idleSessionsTimestamps.Clear();
idleSessionSessionIds.Clear();
await sessions.PruneIdleSessionsAsync(stoppingToken);
}
}
catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
Expand All @@ -98,64 +40,21 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
try
{
List<Task> disposeSessionTasks = [];

foreach (var (sessionKey, _) in handler.Sessions)
{
if (handler.Sessions.TryRemove(sessionKey, out var session))
{
disposeSessionTasks.Add(DisposeSessionAsync(session));
}
}

await Task.WhenAll(disposeSessionTasks);
await sessions.DisposeAllSessionsAsync();
}
finally
{
if (!stoppingToken.IsCancellationRequested)
{
// Something went terribly wrong. A very unexpected exception must be bubbling up, but let's ensure we also stop the application,
// so that it hopefully gets looked at and restarted. This shouldn't really be reachable.
appLifetime.StopApplication();
IdleTrackingBackgroundServiceStoppedUnexpectedly();
appLifetime.StopApplication();
}
}
}
}

private void RemoveAndCloseSession(string sessionId)
{
if (!handler.Sessions.TryRemove(sessionId, out var session))
{
return;
}

LogSessionIdle(session.Id);
// Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
_ = DisposeSessionAsync(session);
}

private async Task DisposeSessionAsync(HttpMcpSession<StreamableHttpServerTransport> session)
{
try
{
await session.DisposeAsync();
}
catch (Exception ex)
{
LogSessionDisposeError(session.Id, ex);
}
}

[LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")]
private partial void LogSessionIdle(string sessionId);

[LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {sessionId}.")]
private partial void LogSessionDisposeError(string sessionId, Exception ex);

[LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded maximum of {maxIdleSessionCount} idle sessions. Now closing sessions active more recently than configured IdleTimeout.")]
private partial void LogMaxSessionIdleCountExceeded(int maxIdleSessionCount);

[LoggerMessage(Level = LogLevel.Critical, Message = "The IdleTrackingBackgroundService has stopped unexpectedly.")]
private partial void IdleTrackingBackgroundServiceStoppedUnexpectedly();
}
18 changes: 9 additions & 9 deletions src/ModelContextProtocol.AspNetCore/SseHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ internal sealed class SseHandler(
IHostApplicationLifetime hostApplicationLifetime,
ILoggerFactory loggerFactory)
{
private readonly ConcurrentDictionary<string, HttpMcpSession<SseResponseStreamTransport>> _sessions = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, SseSession> _sessions = new(StringComparer.Ordinal);

public async Task HandleSseRequestAsync(HttpContext context)
{
Expand All @@ -34,9 +34,9 @@ public async Task HandleSseRequestAsync(HttpContext context)
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId);

var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider);
var sseSession = new SseSession(transport, userIdClaim);

if (!_sessions.TryAdd(sessionId, httpMcpSession))
if (!_sessions.TryAdd(sessionId, sseSession))
{
throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}
Expand All @@ -55,12 +55,10 @@ public async Task HandleSseRequestAsync(HttpContext context)
try
{
await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices);
httpMcpSession.Server = mcpServer;
context.Features.Set(mcpServer);

var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync;
httpMcpSession.ServerRunTask = runSessionAsync(context, mcpServer, cancellationToken);
await httpMcpSession.ServerRunTask;
await runSessionAsync(context, mcpServer, cancellationToken);
}
finally
{
Expand All @@ -87,13 +85,13 @@ public async Task HandleMessageRequestAsync(HttpContext context)
return;
}

if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession))
if (!_sessions.TryGetValue(sessionId.ToString(), out var sseSession))
{
await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
return;
}

if (!httpMcpSession.HasSameUserId(context.User))
if (sseSession.UserId != StreamableHttpHandler.GetUserIdClaim(context.User))
{
await Results.Forbid().ExecuteAsync(context);
return;
Expand All @@ -106,8 +104,10 @@ public async Task HandleMessageRequestAsync(HttpContext context)
return;
}

await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted);
await sseSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted);
context.Response.StatusCode = StatusCodes.Status202Accepted;
await context.Response.WriteAsync("Accepted");
}

private record SseSession(SseResponseStreamTransport Transport, UserIdClaim? UserId);
}
Loading
Loading