Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder
builder.Services.TryAddSingleton<StreamableHttpHandler>();
builder.Services.TryAddSingleton<SseHandler>();
builder.Services.AddHostedService<IdleTrackingBackgroundService>();
builder.Services.AddDataProtection();

if (configureOptions is not null)
{
Expand Down
30 changes: 7 additions & 23 deletions src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

namespace ModelContextProtocol.AspNetCore;

internal sealed class HttpMcpSession<TTransport>(string sessionId, TTransport transport, ClaimsPrincipal user, TimeProvider timeProvider) : IAsyncDisposable
internal sealed class HttpMcpSession<TTransport>(
string sessionId,
TTransport transport,
(string Type, string Value, string Issuer)? userIdClaim,
TimeProvider timeProvider) : IAsyncDisposable
where TTransport : ITransport
{
private int _referenceCount;
Expand All @@ -13,7 +17,7 @@ internal sealed class HttpMcpSession<TTransport>(string sessionId, TTransport tr

public string Id { get; } = sessionId;
public TTransport Transport { get; } = transport;
public (string Type, string Value, string Issuer)? UserIdClaim { get; } = GetUserIdClaim(user);
public (string Type, string Value, string Issuer)? UserIdClaim { get; } = userIdClaim;

public CancellationToken SessionClosed => _disposeCts.Token;

Expand Down Expand Up @@ -63,27 +67,7 @@ public async ValueTask DisposeAsync()
}

public bool HasSameUserId(ClaimsPrincipal user)
=> UserIdClaim == GetUserIdClaim(user);

// SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims.
// However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than
// verifying antiforgery tokens from <form> posts.
private static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user)
{
if (user?.Identity?.IsAuthenticated != true)
{
return null;
}

var claim = user.FindFirst(ClaimTypes.NameIdentifier) ?? user.FindFirst("sub") ?? user.FindFirst(ClaimTypes.Upn);

if (claim is { } idClaim)
{
return (idClaim.Type, idClaim.Value, idClaim.Issuer);
}

return null;
}
=> UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user);

private sealed class UnreferenceDisposable(HttpMcpSession<TTransport> session, TimeProvider timeProvider) : IDisposable
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ public class HttpServerTransportOptions
/// </summary>
public Func<HttpContext, IMcpServer, CancellationToken, Task>? RunSessionHandler { get; set; }

/// <summary>
/// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session
/// to arrive to the same ASP.NET Core application process. If true, the /sse endpoint will be disabled, and
/// client capabilities will be round-tripped as part of the mcp-session-id header instead of stored in memory. Defaults to false.
/// </summary>
public bool Stateless { get; set; }

/// <summary>
/// Represents the duration of time the server will wait between any active requests before timing out an
/// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,27 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
.WithMetadata(new AcceptsMetadata(["application/json"]))
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]))
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted));
streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync)
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync);

// Map legacy HTTP with SSE endpoints.
var sseHandler = endpoints.ServiceProvider.GetRequiredService<SseHandler>();
var sseGroup = mcpGroup.MapGroup("")
.WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}");

sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync)
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync)
.WithMetadata(new AcceptsMetadata(["application/json"]))
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted));

if (!streamableHttpHandler.HttpServerTransportOptions.Stateless)
{
// The GET and DELETE endpoints are not mapped in Stateless mode since there's no way to send unsolicited messages
// for the GET to handle, and there is no server-side state for the DELETE to clean up.
streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync)
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync);

// Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests
// will be handled by the same process as the /sse request.
var sseHandler = endpoints.ServiceProvider.GetRequiredService<SseHandler>();
var sseGroup = mcpGroup.MapGroup("")
.WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}");

sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync)
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync)
.WithMetadata(new AcceptsMetadata(["application/json"]))
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted));
}

return mcpGroup;
}
Expand Down
5 changes: 4 additions & 1 deletion src/ModelContextProtocol.AspNetCore/SseHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ public async Task HandleSseRequestAsync(HttpContext context)
var requestPath = (context.Request.PathBase + context.Request.Path).ToString();
var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)];
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}");
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider);

var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider);

if (!_sessions.TryAdd(sessionId, httpMcpSession))
{
throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
Expand Down
16 changes: 16 additions & 0 deletions src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using ModelContextProtocol.Protocol.Types;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.AspNetCore;

internal class StatelessSessionId
{
[JsonPropertyName("capabilities")]
public ClientCapabilities? Capabilities { get; init; }

[JsonPropertyName("clientInfo")]
public Implementation? ClientInfo { get; init; }

[JsonPropertyName("userIdClaim")]
public (string Type, string Value, string Issuer)? UserIdClaim { get; init; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.AspNetCore;

[JsonSerializable(typeof(StatelessSessionId))]
internal sealed partial class StatelessSessionIdJsonContext : JsonSerializerContext;
Loading
Loading