Skip to content

Commit d5ffe1f

Browse files
Merge branch 'main' of https://github.com/Tyler-R-Kendrick/mcp-csharp-sdk into cancellation-enhancements
2 parents 3e3f95d + ac0bc3a commit d5ffe1f

39 files changed

+1456
-894
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,8 @@ nCrunchTemp_*
7676

7777
# Auto-generated documentation
7878
docs/_site
79-
docs/api
79+
docs/api
80+
81+
# Rider
82+
.idea/
83+
.idea_modules/

Directory.Packages.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
<PackageVersion Include="Serilog.Sinks.Debug" Version="3.0.0" />
6565
<PackageVersion Include="Serilog.Sinks.File" Version="6.0.0" />
6666
<PackageVersion Include="System.Linq.AsyncEnumerable" Version="$(System10Version)" />
67-
<PackageVersion Include="xunit.v3" Version="1.1.0" />
67+
<PackageVersion Include="xunit.v3" Version="2.0.1" />
6868
<PackageVersion Include="xunit.runner.visualstudio" Version="3.0.2" />
6969
</ItemGroup>
7070
</Project>

src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
using Microsoft.AspNetCore.Http;
22
using Microsoft.AspNetCore.Http.Features;
33
using Microsoft.AspNetCore.Routing;
4+
using Microsoft.AspNetCore.Routing.Patterns;
45
using Microsoft.AspNetCore.WebUtilities;
56
using Microsoft.Extensions.DependencyInjection;
7+
using Microsoft.Extensions.Hosting;
68
using Microsoft.Extensions.Logging;
79
using Microsoft.Extensions.Options;
810
using ModelContextProtocol.Protocol.Messages;
911
using ModelContextProtocol.Protocol.Transport;
1012
using ModelContextProtocol.Server;
1113
using ModelContextProtocol.Utils.Json;
1214
using System.Collections.Concurrent;
15+
using System.Diagnostics.CodeAnalysis;
1316
using System.Security.Cryptography;
1417

1518
namespace Microsoft.AspNetCore.Builder;
@@ -23,53 +26,87 @@ public static class McpEndpointRouteBuilderExtensions
2326
/// Sets up endpoints for handling MCP HTTP Streaming transport.
2427
/// </summary>
2528
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
26-
/// <param name="runSession">Provides an optional asynchronous callback for handling new MCP sessions.</param>
29+
/// <param name="pattern">The route pattern prefix to map to.</param>
30+
/// <param name="configureOptionsAsync">Configure per-session options.</param>
31+
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
2732
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
28-
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func<HttpContext, IMcpServer, CancellationToken, Task>? runSession = null)
33+
public static IEndpointConventionBuilder MapMcp(
34+
this IEndpointRouteBuilder endpoints,
35+
[StringSyntax("Route")] string pattern = "",
36+
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
37+
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
38+
=> endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);
39+
40+
/// <summary>
41+
/// Sets up endpoints for handling MCP HTTP Streaming transport.
42+
/// </summary>
43+
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
44+
/// <param name="pattern">The route pattern prefix to map to.</param>
45+
/// <param name="configureOptionsAsync">Configure per-session options.</param>
46+
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
47+
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
48+
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
49+
RoutePattern pattern,
50+
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
51+
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
2952
{
3053
ConcurrentDictionary<string, SseResponseStreamTransport> _sessions = new(StringComparer.Ordinal);
3154

3255
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
33-
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
56+
var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
57+
var optionsFactory = endpoints.ServiceProvider.GetRequiredService<IOptionsFactory<McpServerOptions>>();
58+
var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService<IHostApplicationLifetime>();
3459

35-
var routeGroup = endpoints.MapGroup("");
60+
var routeGroup = endpoints.MapGroup(pattern);
3661

3762
routeGroup.MapGet("/sse", async context =>
3863
{
39-
var response = context.Response;
40-
var requestAborted = context.RequestAborted;
64+
// If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
65+
// which defaults to 30 seconds.
66+
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
67+
var cancellationToken = sseCts.Token;
4168

69+
var response = context.Response;
4270
response.Headers.ContentType = "text/event-stream";
4371
response.Headers.CacheControl = "no-cache,no-store";
4472

73+
// Make sure we disable all response buffering for SSE
74+
context.Response.Headers.ContentEncoding = "identity";
75+
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
76+
4577
var sessionId = MakeNewSessionId();
4678
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
4779
if (!_sessions.TryAdd(sessionId, transport))
4880
{
4981
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
5082
}
5183

52-
try
84+
var options = optionsSnapshot.Value;
85+
if (configureOptionsAsync is not null)
5386
{
54-
// Make sure we disable all response buffering for SSE
55-
context.Response.Headers.ContentEncoding = "identity";
56-
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
87+
options = optionsFactory.Create(Options.DefaultName);
88+
await configureOptionsAsync.Invoke(context, options, cancellationToken);
89+
}
5790

58-
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
59-
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
91+
try
92+
{
93+
var transportTask = transport.RunAsync(cancellationToken);
6094

6195
try
6296
{
63-
runSession ??= RunSession;
64-
await runSession(context, server, requestAborted);
97+
await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
98+
context.Features.Set(mcpServer);
99+
100+
runSessionAsync ??= RunSession;
101+
await runSessionAsync(context, mcpServer, cancellationToken);
65102
}
66103
finally
67104
{
68105
await transport.DisposeAsync();
69106
await transportTask;
70107
}
71108
}
72-
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
109+
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
73110
{
74111
// RequestAborted always triggers when the client disconnects before a complete response body is written,
75112
// but this is how SSE connections are typically closed.

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, Mc
4949
{
5050
if (capabilities.NotificationHandlers is { } notificationHandlers)
5151
{
52-
NotificationHandlers.AddRange(notificationHandlers);
52+
NotificationHandlers.RegisterRange(notificationHandlers);
5353
}
5454

5555
if (capabilities.Sampling is { } samplingCapability)
@@ -106,7 +106,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
106106
{
107107
// Connect transport
108108
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
109-
StartSession(_sessionTransport);
109+
InitializeSession(_sessionTransport);
110+
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
111+
// The base class handles cleaning up the session in DisposeAsync without our help.
112+
StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);
110113

111114
// Perform initialization sequence
112115
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

0 commit comments

Comments
 (0)