From 318fcafa8f62c3421ab3425e3d606b13b3ae4dff Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 9 Apr 2025 18:13:16 -0700 Subject: [PATCH 01/12] Simplify ModelContextProtocol.AspNetCore README --- src/ModelContextProtocol.AspNetCore/README.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/README.md b/src/ModelContextProtocol.AspNetCore/README.md index 457321d09..3452a58ed 100644 --- a/src/ModelContextProtocol.AspNetCore/README.md +++ b/src/ModelContextProtocol.AspNetCore/README.md @@ -30,20 +30,16 @@ dotnet add package ModelContextProtocol.AspNetCore --prerelease ```csharp // Program.cs -using ModelContextProtocol.Server; +?using ModelContextProtocol.Server; using System.ComponentModel; var builder = WebApplication.CreateBuilder(args); -builder.WebHost.ConfigureKestrel(options => -{ - options.ListenLocalhost(3001); -}); builder.Services.AddMcpServer().WithToolsFromAssembly(); var app = builder.Build(); app.MapMcp(); -app.Run(); +app.Run("http://localhost:3001"); [McpServerToolType] public static class EchoTool From cad85a65e5cb21d0e07e9b94f3b2ae88d3b40f6d Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 12:39:00 -0700 Subject: [PATCH 02/12] Add StreamableHttpHandler --- .../HttpMcpServerBuilderExtensions.cs | 34 ++++ .../HttpServerTransportOptions.cs | 24 +++ .../McpEndpointRouteBuilderExtensions.cs | 145 +----------------- .../StreamableHttpHandler.cs | 138 +++++++++++++++++ .../Server/MapMcpTests.cs | 2 + .../SseIntegrationTests.cs | 25 ++- 6 files changed, 223 insertions(+), 145 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs create mode 100644 src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs create mode 100644 src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs new file mode 100644 index 000000000..f67c57afd --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -0,0 +1,34 @@ +using Microsoft.Extensions.DependencyInjection.Extensions; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Server; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Provides methods for configuring HTTP MCP servers via dependency injection. +/// +public static class HttpMcpServerBuilderExtensions +{ + /// + /// Adds the services necessary for + /// to handle MCP requests and sessions using the MCP HTTP Streaming transport. For more information on configuring the underlying HTTP server + /// to control things like port binding custom TLS certificates, see the Minimal APIs quick reference. + /// + /// The builder instance. + /// Configures options for the HTTP Streaming transport. This allows configuring per-session + /// and running logic before and after a session. + /// The builder provided in . + /// is . + public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action? configureOptions = null) + { + ArgumentNullException.ThrowIfNull(builder); + builder.Services.TryAddSingleton(); + + if (configureOptions is not null) + { + builder.Services.Configure(configureOptions); + } + + return builder; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs new file mode 100644 index 000000000..850dac244 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -0,0 +1,24 @@ +using Microsoft.AspNetCore.Http; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Configuration options for . +/// which implements the Streaming HTTP transport for the Model Context Protocol. +/// See the protocol specification for details on the Streamable HTTP transport. +/// +public class HttpServerTransportOptions +{ + /// + /// Gets or sets an optional asynchronous callback to configure per-session + /// with access to the of the request that initiated the session. + /// + public Func? ConfigureSessionOptions { get; set; } + + /// + /// Gets or sets an optional asynchronous callback for running new MCP sessions manually. + /// This is useful for running logic before a sessions starts and after it completes. + /// + public Func? RunSessionHandler { get; set; } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 818af8ba5..16ccbb282 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -1,19 +1,7 @@ -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Routing; -using Microsoft.AspNetCore.Routing.Patterns; -using Microsoft.AspNetCore.WebUtilities; +using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; -using ModelContextProtocol.Utils.Json; -using System.Collections.Concurrent; +using ModelContextProtocol.AspNetCore; using System.Diagnostics.CodeAnalysis; -using System.Security.Cryptography; namespace Microsoft.AspNetCore.Builder; @@ -24,136 +12,19 @@ public static class McpEndpointRouteBuilderExtensions { /// /// Sets up endpoints for handling MCP HTTP Streaming transport. + /// See the protocol specification for details about the Streamable HTTP transport. /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Configure per-session options. - /// Provides an optional asynchronous callback for handling new MCP sessions. /// Returns a builder for configuring additional endpoint conventions like authorization policies. - public static IEndpointConventionBuilder MapMcp( - this IEndpointRouteBuilder endpoints, - [StringSyntax("Route")] string pattern = "", - Func? configureOptionsAsync = null, - Func? runSessionAsync = null) - => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync); - - /// - /// Sets up endpoints for handling MCP HTTP Streaming transport. - /// - /// The web application to attach MCP HTTP endpoints. - /// The route pattern prefix to map to. - /// Configure per-session options. - /// Provides an optional asynchronous callback for handling new MCP sessions. - /// Returns a builder for configuring additional endpoint conventions like authorization policies. - public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, - RoutePattern pattern, - Func? configureOptionsAsync = null, - Func? runSessionAsync = null) + public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "") { - ConcurrentDictionary _sessions = new(StringComparer.Ordinal); - - var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); - var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>(); - var optionsFactory = endpoints.ServiceProvider.GetRequiredService>(); - var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService(); + var handler = endpoints.ServiceProvider.GetService() ?? + throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code."); var routeGroup = endpoints.MapGroup(pattern); - - routeGroup.MapGet("/sse", async context => - { - // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout - // which defaults to 30 seconds. - using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); - var cancellationToken = sseCts.Token; - - var response = context.Response; - response.Headers.ContentType = "text/event-stream"; - response.Headers.CacheControl = "no-cache,no-store"; - - // Make sure we disable all response buffering for SSE - context.Response.Headers.ContentEncoding = "identity"; - context.Features.GetRequiredFeature().DisableBuffering(); - - var sessionId = MakeNewSessionId(); - await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}"); - if (!_sessions.TryAdd(sessionId, transport)) - { - throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); - } - - var options = optionsSnapshot.Value; - if (configureOptionsAsync is not null) - { - options = optionsFactory.Create(Options.DefaultName); - await configureOptionsAsync.Invoke(context, options, cancellationToken); - } - - try - { - var transportTask = transport.RunAsync(cancellationToken); - - try - { - await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); - context.Features.Set(mcpServer); - - runSessionAsync ??= RunSession; - await runSessionAsync(context, mcpServer, cancellationToken); - } - finally - { - await transport.DisposeAsync(); - await transportTask; - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - // RequestAborted always triggers when the client disconnects before a complete response body is written, - // but this is how SSE connections are typically closed. - } - finally - { - _sessions.TryRemove(sessionId, out _); - } - }); - - routeGroup.MapPost("/message", async context => - { - if (!context.Request.Query.TryGetValue("sessionId", out var sessionId)) - { - await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context); - return; - } - - if (!_sessions.TryGetValue(sessionId.ToString(), out var transport)) - { - await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); - return; - } - - var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); - if (message is null) - { - await Results.BadRequest("No message in request body.").ExecuteAsync(context); - return; - } - - await transport.OnMessageReceivedAsync(message, context.RequestAborted); - context.Response.StatusCode = StatusCodes.Status202Accepted; - await context.Response.WriteAsync("Accepted"); - }); - + routeGroup.MapGet("/sse", handler.HandleRequestAsync); + routeGroup.MapPost("/message", handler.HandleRequestAsync); return routeGroup; } - - private static Task RunSession(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) - => session.RunAsync(requestAborted); - - private static string MakeNewSessionId() - { - // 128 bits - Span buffer = stackalloc byte[16]; - RandomNumberGenerator.Fill(buffer); - return WebEncoders.Base64UrlEncode(buffer); - } } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs new file mode 100644 index 000000000..4b5641032 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -0,0 +1,138 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Collections.Concurrent; +using System.Security.Cryptography; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed class StreamableHttpHandler( + IOptions mcpServerOptionsSnapshot, + IOptionsFactory mcpServerOptionsFactory, + IOptions httpMcpServerOptions, + IHostApplicationLifetime hostApplicationLifetime, + ILoggerFactory loggerFactory) +{ + + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + private readonly ILogger _logger = loggerFactory.CreateLogger(); + + public async Task HandleRequestAsync(HttpContext context) + { + if (context.Request.Method == HttpMethods.Get) + { + await HandleSseRequestAsync(context); + } + else if (context.Request.Method == HttpMethods.Post) + { + await HandleMessageRequestAsync(context); + } + else + { + context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed; + await context.Response.WriteAsync("Method Not Allowed"); + } + } + + public async Task HandleSseRequestAsync(HttpContext context) + { + // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout + // which defaults to 30 seconds. + using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); + var cancellationToken = sseCts.Token; + + var response = context.Response; + response.Headers.ContentType = "text/event-stream"; + response.Headers.CacheControl = "no-cache,no-store"; + + // Make sure we disable all response buffering for SSE + context.Response.Headers.ContentEncoding = "identity"; + context.Features.GetRequiredFeature().DisableBuffering(); + + var sessionId = MakeNewSessionId(); + await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}"); + if (!_sessions.TryAdd(sessionId, transport)) + { + throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); + } + + var mcpServerOptions = mcpServerOptionsSnapshot.Value; + if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) + { + mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); + await configureSessionOptions(context, mcpServerOptions, cancellationToken); + } + + try + { + var transportTask = transport.RunAsync(cancellationToken); + + try + { + await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); + context.Features.Set(mcpServer); + + var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync; + await runSessionAsync(context, mcpServer, cancellationToken); + } + finally + { + await transport.DisposeAsync(); + await transportTask; + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // RequestAborted always triggers when the client disconnects before a complete response body is written, + // but this is how SSE connections are typically closed. + } + finally + { + _sessions.TryRemove(sessionId, out _); + } + } + + public async Task HandleMessageRequestAsync(HttpContext context) + { + if (!context.Request.Query.TryGetValue("sessionId", out var sessionId)) + { + await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context); + return; + } + + if (!_sessions.TryGetValue(sessionId.ToString(), out var transport)) + { + await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); + return; + } + + var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); + if (message is null) + { + await Results.BadRequest("No message in request body.").ExecuteAsync(context); + return; + } + + await transport.OnMessageReceivedAsync(message, context.RequestAborted); + context.Response.StatusCode = StatusCodes.Status202Accepted; + await context.Response.WriteAsync("Accepted"); + } + + private static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) + => session.RunAsync(requestAborted); + + private static string MakeNewSessionId() + { + // 128 bits + Span buffer = stackalloc byte[16]; + RandomNumberGenerator.Fill(buffer); + return WebEncoders.Base64UrlEncode(buffer); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs index 5a3c4181f..709dee2c5 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs @@ -1,4 +1,5 @@ using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.AspNetCore.Tests.Server; @@ -8,6 +9,7 @@ public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTe [Fact] public async Task Allows_Customizing_Route() { + Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); app.MapMcp("/mcp"); await app.StartAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 83c601aa4..a1195bfea 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -34,6 +34,7 @@ private Task ConnectMcpClient(HttpClient httpClient, McpClientOption [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() { + Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); @@ -68,16 +69,23 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() { var receivedNotification = new TaskCompletionSource(); - await using var app = Builder.Build(); - app.MapMcp(runSessionAsync: (httpContext, mcpServer, cancellationToken) => - { - mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) => + Builder.Services.AddMcpServer() + .WithHttpTransport(httpTransportOptions => { - Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue()); - await mcpServer.SendNotificationAsync("test/notification", new Envelope { Message = "Hello from server!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: cancellationToken); + httpTransportOptions.RunSessionHandler = (httpContext, mcpServer, cancellationToken) => + { + // We could also use ServerCapabilities.NotificationHandlers, but it's good to have some test coverage of RunSessionHandler. + mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) => + { + Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue()); + await mcpServer.SendNotificationAsync("test/notification", new Envelope { Message = "Hello from server!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: cancellationToken); + }); + return mcpServer.RunAsync(cancellationToken); + }; }); - return mcpServer.RunAsync(cancellationToken); - }); + + await using var app = Builder.Build(); + app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); using var httpClient = CreateHttpClient(); @@ -107,6 +115,7 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() { Interlocked.Increment(ref firstOptionsCallbackCallCount); }) + .WithHttpTransport() .WithTools(); Builder.Services.AddMcpServer(options => From 3b13429cc7ed74438a0235fe075c0a40932683b6 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 14:48:47 -0700 Subject: [PATCH 03/12] Use "docker info" in CheckIsDockerAvailable --- tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs index d5fff66f0..ffd8859a0 100644 --- a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs +++ b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs @@ -65,7 +65,8 @@ private static bool CheckIsDockerAvailable() ProcessStartInfo processStartInfo = new() { FileName = "docker", - Arguments = "--version", + // "docker info" returns a non-zero exit code if docker engine is not running. + Arguments = "info", UseShellExecute = false, }; From 6b45f35ad85bb173ee5af45b2921873ae85267fb Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 14:52:29 -0700 Subject: [PATCH 04/12] Call WithHttpTransport in samples, tests and README --- samples/AspNetCoreSseServer/Program.cs | 1 + src/ModelContextProtocol.AspNetCore/README.md | 4 +++- tests/ModelContextProtocol.TestSseServer/Program.cs | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index 306a6e8f7..6bf906ae4 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -2,6 +2,7 @@ var builder = WebApplication.CreateBuilder(args); builder.Services.AddMcpServer() + .WithHttpTransport() .WithTools() .WithTools(); diff --git a/src/ModelContextProtocol.AspNetCore/README.md b/src/ModelContextProtocol.AspNetCore/README.md index 3452a58ed..d2f79a3b7 100644 --- a/src/ModelContextProtocol.AspNetCore/README.md +++ b/src/ModelContextProtocol.AspNetCore/README.md @@ -34,7 +34,9 @@ dotnet add package ModelContextProtocol.AspNetCore --prerelease using System.ComponentModel; var builder = WebApplication.CreateBuilder(args); -builder.Services.AddMcpServer().WithToolsFromAssembly(); +builder.Services.AddMcpServer() + .WithHttpTransport() + .WithToolsFromAssembly(); var app = builder.Build(); app.MapMcp(); diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index f364c4a12..0c04c8477 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -408,7 +408,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide builder.Logging.AddProvider(loggerProvider); } - builder.Services.AddMcpServer(ConfigureOptions); + builder.Services.AddMcpServer(ConfigureOptions) + .WithHttpTransport(); var app = builder.Build(); app.UseRouting(); From 82813d2bfae2997799feb2850788b5d4e35ae74d Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 15:20:56 -0700 Subject: [PATCH 05/12] Cleanup test namespaces --- .../{Server => }/MapMcpTests.cs | 4 ++-- .../SseIntegrationTests.cs | 4 ++-- .../SseServerIntegrationTestFixture.cs | 4 ++-- .../SseServerIntegrationTests.cs | 2 +- .../Utils/KestrelInMemoryConnection.cs | 2 +- .../Utils/KestrelInMemoryTest.cs | 3 ++- .../Utils/KestrelInMemoryTransport.cs | 2 +- .../ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs | 1 - .../Utils/XunitLoggerProvider.cs | 2 +- tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs | 1 - tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs | 2 +- 11 files changed, 13 insertions(+), 14 deletions(-) rename tests/ModelContextProtocol.AspNetCore.Tests/{Server => }/MapMcpTests.cs (87%) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs similarity index 87% rename from tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs rename to tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 709dee2c5..4ed079e12 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Server/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -1,8 +1,8 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Tests.Utils; +using ModelContextProtocol.AspNetCore.Tests.Utils; -namespace ModelContextProtocol.AspNetCore.Tests.Server; +namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index a1195bfea..e7ed73ef6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -4,16 +4,16 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; using System.Text.Json.Serialization; using TestServerWithHosting.Tools; -namespace ModelContextProtocol.Tests; +namespace ModelContextProtocol.AspNetCore.Tests; public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 286ae3cfa..41b8d8fa7 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -1,11 +1,11 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; -namespace ModelContextProtocol.Tests; +namespace ModelContextProtocol.AspNetCore.Tests; public class SseServerIntegrationTestFixture : IAsyncDisposable { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index 280a757d8..10a6316a9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -4,7 +4,7 @@ using System.Net; using System.Text; -namespace ModelContextProtocol.Tests; +namespace ModelContextProtocol.AspNetCore.Tests; public class SseServerIntegrationTests : LoggedTest, IClassFixture { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs index c823261cf..0269ea7bf 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs @@ -2,7 +2,7 @@ using Microsoft.AspNetCore.Http.Features; using System.IO.Pipelines; -namespace ModelContextProtocol.Tests.Utils; +namespace ModelContextProtocol.AspNetCore.Tests.Utils; public sealed class KestrelInMemoryConnection : ConnectionContext { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index 5e440dcdd..00280e22c 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -2,8 +2,9 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using ModelContextProtocol.Tests.Utils; -namespace ModelContextProtocol.Tests.Utils; +namespace ModelContextProtocol.AspNetCore.Tests.Utils; public class KestrelInMemoryTest : LoggedTest { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs index 586b1650a..399e9a833 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs @@ -2,7 +2,7 @@ using System.Net; using System.Threading.Channels; -namespace ModelContextProtocol.Tests.Utils; +namespace ModelContextProtocol.AspNetCore.Tests.Utils; public sealed class KestrelInMemoryTransport : IConnectionListenerFactory, IConnectionListener { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs index aa1ecbc27..a2e9e2ba2 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Test.Utils; namespace ModelContextProtocol.Tests.Utils; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/XunitLoggerProvider.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/XunitLoggerProvider.cs index c76d2649a..f66a828a6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/XunitLoggerProvider.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/XunitLoggerProvider.cs @@ -2,7 +2,7 @@ using System.Text; using Microsoft.Extensions.Logging; -namespace ModelContextProtocol.Test.Utils; +namespace ModelContextProtocol.Tests.Utils; public class XunitLoggerProvider(ITestOutputHelper output) : ILoggerProvider { diff --git a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs index aa1ecbc27..a2e9e2ba2 100644 --- a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs +++ b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Test.Utils; namespace ModelContextProtocol.Tests.Utils; diff --git a/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs b/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs index c76d2649a..f66a828a6 100644 --- a/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs +++ b/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs @@ -2,7 +2,7 @@ using System.Text; using Microsoft.Extensions.Logging; -namespace ModelContextProtocol.Test.Utils; +namespace ModelContextProtocol.Tests.Utils; public class XunitLoggerProvider(ITestOutputHelper output) : ILoggerProvider { From 65a82ec0186475a6147dbb0d8c3467f6550a712b Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 15:35:42 -0700 Subject: [PATCH 06/12] Add CanConnect_WithMcpClient_AfterCustomizingRoute test --- .../StreamableHttpHandler.cs | 2 +- .../MapMcpTests.cs | 32 +++++++++++++++++++ .../SseIntegrationTests.cs | 1 - 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 4b5641032..10dd01f9a 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -57,7 +57,7 @@ public async Task HandleSseRequestAsync(HttpContext context) context.Features.GetRequiredFeature().DisableBuffering(); var sessionId = MakeNewSessionId(); - await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}"); + await using var transport = new SseResponseStreamTransport(response.Body, $"message?sessionId={sessionId}"); if (!_sessions.TryAdd(sessionId, transport)) { throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 4ed079e12..8c5b32ed5 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -1,6 +1,8 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; namespace ModelContextProtocol.AspNetCore.Tests; @@ -11,11 +13,41 @@ public async Task Allows_Customizing_Route() { Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); + app.MapMcp("/mcp"); + await app.StartAsync(TestContext.Current.CancellationToken); using var httpClient = CreateHttpClient(); using var response = await httpClient.GetAsync("http://localhost/mcp/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); response.EnsureSuccessStatusCode(); } + + [Fact] + public async Task CanConnect_WithMcpClient_AfterCustomizingRoute() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "TestCustomRouteServer", + Version = "1.0.0", + }; + }).WithHttpTransport(); + await using var app = Builder.Build(); + + app.MapMcp("/mcp"); + + await app.StartAsync(TestContext.Current.CancellationToken); + + using var httpClient = CreateHttpClient(); + var sseClientTransportOptions = new SseClientTransportOptions() + { + Endpoint = new Uri("http://localhost/mcp/sse"), + }; + await using var transport = new SseClientTransport(sseClientTransportOptions, httpClient, LoggerFactory); + var mcpClient = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index e7ed73ef6..9beb88cfa 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -30,7 +30,6 @@ private Task ConnectMcpClient(HttpClient httpClient, McpClientOption LoggerFactory, TestContext.Current.CancellationToken); - [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() { From 5de8bc9587e957d210f99d0e9b5588856603a372 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 16:01:09 -0700 Subject: [PATCH 07/12] Simplify relative URI handling --- .../Transport/SseClientSessionTransport.cs | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 5348995e7..893ea2c16 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -285,21 +285,8 @@ private void HandleEndpointEvent(string data) return; } - // Check if data is absolute URI - if (data.StartsWith("http://", StringComparison.OrdinalIgnoreCase) || data.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) - { - // Since the endpoint is an absolute URI, we can use it directly - _messageEndpoint = new Uri(data); - } - else - { - // If the endpoint is a relative URI, we need to combine it with the relative path of the SSE endpoint - var baseUriBuilder = new UriBuilder(_sseEndpoint); - - - // Instead of manually concatenating strings, use the Uri class's composition capabilities - _messageEndpoint = new Uri(baseUriBuilder.Uri, data); - } + // If data is an absolute URL, the Uri will be constructed entirely from it and not the _sseEndpoint. + _messageEndpoint = new Uri(_sseEndpoint, data); // Set connected state SetConnected(true); From 37a1d4e5675d989eaff593041cd514741a099782 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 16:44:19 -0700 Subject: [PATCH 08/12] Handle request made directly to the MapMcp route pattern --- .../McpEndpointRouteBuilderExtensions.cs | 1 + .../MapMcpTests.cs | 89 ++++++++++++++++--- .../SseIntegrationTests.cs | 12 +-- .../Utils/KestrelInMemoryTest.cs | 19 ++-- 4 files changed, 93 insertions(+), 28 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 16ccbb282..ac424cc8b 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -23,6 +23,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code."); var routeGroup = endpoints.MapGroup(pattern); + routeGroup.MapGet("", handler.HandleRequestAsync); routeGroup.MapGet("/sse", handler.HandleRequestAsync); routeGroup.MapPost("/message", handler.HandleRequestAsync); return routeGroup; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 8c5b32ed5..508802259 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -1,13 +1,36 @@ using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Security.Claims; namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { + private async Task ConnectAsync(string? path = null) + { + var sseClientTransportOptions = new SseClientTransportOptions() + { + Endpoint = new Uri($"http://localhost{path}"), + }; + await using var transport = new SseClientTransport(sseClientTransportOptions, HttpClient, LoggerFactory); + return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task MapMcp_ThrowsInvalidOperationException_IfWithHttpTransportIsNotCalled() + { + Builder.Services.AddMcpServer(); + await using var app = Builder.Build(); + var exception = Assert.Throws(() => app.MapMcp()); + Assert.StartsWith("You must call WithHttpTransport()", exception.Message); + } + [Fact] public async Task Allows_Customizing_Route() { @@ -18,13 +41,16 @@ public async Task Allows_Customizing_Route() await app.StartAsync(TestContext.Current.CancellationToken); - using var httpClient = CreateHttpClient(); - using var response = await httpClient.GetAsync("http://localhost/mcp/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + using var response = await HttpClient.GetAsync("http://localhost/mcp/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); response.EnsureSuccessStatusCode(); } - [Fact] - public async Task CanConnect_WithMcpClient_AfterCustomizingRoute() + [Theory] + [InlineData("/a", "/a/sse")] + [InlineData("/a", "/a/")] + [InlineData("/a/", "/a/sse")] + [InlineData("/a/", "/a/")] + public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePattern, string requestPath) { Builder.Services.AddMcpServer(options => { @@ -36,18 +62,57 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute() }).WithHttpTransport(); await using var app = Builder.Build(); - app.MapMcp("/mcp"); + app.MapMcp(routePattern); await app.StartAsync(TestContext.Current.CancellationToken); - using var httpClient = CreateHttpClient(); - var sseClientTransportOptions = new SseClientTransportOptions() - { - Endpoint = new Uri("http://localhost/mcp/sse"), - }; - await using var transport = new SseClientTransport(sseClientTransportOptions, httpClient, LoggerFactory); - var mcpClient = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var mcpClient = await ConnectAsync(requestPath); Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } + + [Fact] + public async Task Can_UseHttpContextAccessor_InTool() + { + Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => + { + return async context => + { + context.User = new ClaimsPrincipal(new ClaimsIdentity([new Claim("name", "TestUser")], "TestAuthType", "name", "role")); + await next(context); + }; + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var mcpClient = await ConnectAsync(); + + var response = await mcpClient.CallToolAsync( + "EchoWithUserName", + new Dictionary() { ["message"] = "Hello world!" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(response.Content); + Assert.Equal("TestUser: Hello world!", content.Text); + } + + [McpServerToolType] + private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor) + { + [McpServerTool, Description("Echoes the input back to the client with their user name.")] + public string EchoWithUserName(string message) + { + var httpContext = contextAccessor.HttpContext ?? throw new Exception("HttpContext unavailable!"); + var userName = httpContext.User.Identity?.Name ?? "anonymous"; + return $"{userName}: {message}"; + } + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 9beb88cfa..d7ac0f594 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -38,8 +38,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - using var httpClient = CreateHttpClient(); - await using var mcpClient = await ConnectMcpClient(httpClient); + await using var mcpClient = await ConnectMcpClient(HttpClient); // Send a test message through POST endpoint await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); @@ -54,8 +53,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU MapAbsoluteEndpointUriMcp(app); await app.StartAsync(TestContext.Current.CancellationToken); - using var httpClient = CreateHttpClient(); - await using var mcpClient = await ConnectMcpClient(httpClient); + await using var mcpClient = await ConnectMcpClient(HttpClient); // Send a test message through POST endpoint await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); @@ -87,8 +85,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - using var httpClient = CreateHttpClient(); - await using var mcpClient = await ConnectMcpClient(httpClient); + await using var mcpClient = await ConnectMcpClient(HttpClient); mcpClient.RegisterNotificationHandler("test/notification", (args, ca) => { @@ -128,8 +125,7 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - using var httpClient = CreateHttpClient(); - await using var mcpClient = await ConnectMcpClient(httpClient); + await using var mcpClient = await ConnectMcpClient(HttpClient); // Options can be lazily initialized, but they must be instantiated by the time an MCP client can finish connecting. // Callbacks can be called multiple times if configureOptionsAsync is configured, because that uses the IOptionsFactory, diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index 00280e22c..b10a0d674 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -19,21 +19,24 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) Builder.Services.RemoveAll(); Builder.Services.AddSingleton(_inMemoryTransport); Builder.Services.AddSingleton(LoggerProvider); - } - - public WebApplicationBuilder Builder { get; } - public HttpClient CreateHttpClient() - { - var socketsHttpHandler = new SocketsHttpHandler() + HttpClient = new HttpClient(new SocketsHttpHandler() { ConnectCallback = (context, token) => { var connection = _inMemoryTransport.CreateConnection(); return new(connection.ClientStream); }, - }; + }); + } - return new HttpClient(socketsHttpHandler); + public WebApplicationBuilder Builder { get; } + + public HttpClient HttpClient { get; } + + public override void Dispose() + { + HttpClient.Dispose(); + base.Dispose(); } } From f47ad26be1c27e3fc81229d8e5bde3a9821dd3d2 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 17:42:44 -0700 Subject: [PATCH 09/12] Add Messages_FromNewUser_AreRejected test --- .../HttpMcpSession.cs | 39 +++++++++++++++++++ .../StreamableHttpHandler.cs | 15 +++++-- .../MapMcpTests.cs | 38 +++++++++++++++++- 3 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs new file mode 100644 index 000000000..56cbad375 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -0,0 +1,39 @@ +using ModelContextProtocol.Protocol.Transport; +using System.Security.Claims; + +namespace ModelContextProtocol.AspNetCore; + +internal class HttpMcpSession +{ + public HttpMcpSession(SseResponseStreamTransport transport, ClaimsPrincipal user) + { + Transport = transport; + UserIdClaim = GetUserIdClaim(user); + } + + public SseResponseStreamTransport Transport { get; } + public (string ClaimType, string ClaimValue)? UserIdClaim { get; } + + public bool HasSameUserId(ClaimsPrincipal user) + => UserIdClaim?.ClaimValue == GetUserIdClaim(user)?.ClaimValue; + + // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. + // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than + // verifying antiforgery tokens from
posts. + private static (string ClaimType, string ClaimValue)? GetUserIdClaim(ClaimsPrincipal user) + { + if (user?.Identity?.IsAuthenticated != true) + { + return null; + } + + var claim = user.FindFirst(ClaimTypes.NameIdentifier) ?? user.FindFirst("sub") ?? user.FindFirst(ClaimTypes.Upn); + + if (claim is { } idClaim) + { + return (idClaim.Type, idClaim.Value); + } + + return null; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 10dd01f9a..a8455d871 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -21,7 +21,7 @@ internal sealed class StreamableHttpHandler( ILoggerFactory loggerFactory) { - private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); private readonly ILogger _logger = loggerFactory.CreateLogger(); public async Task HandleRequestAsync(HttpContext context) @@ -58,7 +58,8 @@ public async Task HandleSseRequestAsync(HttpContext context) var sessionId = MakeNewSessionId(); await using var transport = new SseResponseStreamTransport(response.Body, $"message?sessionId={sessionId}"); - if (!_sessions.TryAdd(sessionId, transport)) + var httpMcpSession = new HttpMcpSession(transport, context.User); + if (!_sessions.TryAdd(sessionId, httpMcpSession)) { throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } @@ -107,12 +108,18 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - if (!_sessions.TryGetValue(sessionId.ToString(), out var transport)) + if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession)) { await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); return; } + if (!httpMcpSession.HasSameUserId(context.User)) + { + await Results.Forbid().ExecuteAsync(context); + return; + } + var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); if (message is null) { @@ -120,7 +127,7 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - await transport.OnMessageReceivedAsync(message, context.RequestAborted); + await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); context.Response.StatusCode = StatusCodes.Status202Accepted; await context.Response.WriteAsync("Accepted"); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 508802259..f7cb2c8a6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -6,6 +6,7 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; using System.ComponentModel; +using System.Net; using System.Security.Claims; namespace ModelContextProtocol.AspNetCore.Tests; @@ -84,7 +85,7 @@ public async Task Can_UseHttpContextAccessor_InTool() { return async context => { - context.User = new ClaimsPrincipal(new ClaimsIdentity([new Claim("name", "TestUser")], "TestAuthType", "name", "role")); + context.User = CreateUser("TestUser"); await next(context); }; }); @@ -104,6 +105,41 @@ public async Task Can_UseHttpContextAccessor_InTool() Assert.Equal("TestUser: Hello world!", content.Text); } + + [Fact] + public async Task Messages_FromNewUser_AreRejected() + { + Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + + // Add an authentication scheme that will send a 403 Forbidden response. + Builder.Services.AddAuthentication().AddBearerToken(); + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => + { + var i = 0; + return async context => + { + context.User = CreateUser($"TestUser{Interlocked.Increment(ref i)}"); + await next(context); + }; + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var httpRequestException = await Assert.ThrowsAsync(() => ConnectAsync()); + Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode); + } + + private ClaimsPrincipal CreateUser(string name) + => new ClaimsPrincipal(new ClaimsIdentity( + [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)], + "TestAuthType", "name", "role")); + [McpServerToolType] private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor) { From b5a1bfd1aea10a4e8fe53e1bf97ac1733c262d71 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 17:58:26 -0700 Subject: [PATCH 10/12] Fix README --- src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs | 6 +++--- src/ModelContextProtocol.AspNetCore/README.md | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 56cbad375..1ca86105a 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -12,7 +12,7 @@ public HttpMcpSession(SseResponseStreamTransport transport, ClaimsPrincipal user } public SseResponseStreamTransport Transport { get; } - public (string ClaimType, string ClaimValue)? UserIdClaim { get; } + public (string ClaimType, string ClaimValue, string Issuer)? UserIdClaim { get; } public bool HasSameUserId(ClaimsPrincipal user) => UserIdClaim?.ClaimValue == GetUserIdClaim(user)?.ClaimValue; @@ -20,7 +20,7 @@ public bool HasSameUserId(ClaimsPrincipal user) // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than // verifying antiforgery tokens from posts. - private static (string ClaimType, string ClaimValue)? GetUserIdClaim(ClaimsPrincipal user) + private static (string ClaimType, string ClaimValue, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) { if (user?.Identity?.IsAuthenticated != true) { @@ -31,7 +31,7 @@ private static (string ClaimType, string ClaimValue)? GetUserIdClaim(ClaimsPrinc if (claim is { } idClaim) { - return (idClaim.Type, idClaim.Value); + return (idClaim.Type, idClaim.Value, idClaim.Issuer); } return null; diff --git a/src/ModelContextProtocol.AspNetCore/README.md b/src/ModelContextProtocol.AspNetCore/README.md index d2f79a3b7..317acb72a 100644 --- a/src/ModelContextProtocol.AspNetCore/README.md +++ b/src/ModelContextProtocol.AspNetCore/README.md @@ -30,7 +30,7 @@ dotnet add package ModelContextProtocol.AspNetCore --prerelease ```csharp // Program.cs -?using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using System.ComponentModel; var builder = WebApplication.CreateBuilder(args); From cc2a4f1160c2e1eb42dab4499fbc96263c88e8db Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 11 Apr 2025 18:02:23 -0700 Subject: [PATCH 11/12] Shorten UserIdClaim ValueTuple names --- src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 1ca86105a..216962a8b 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -12,15 +12,15 @@ public HttpMcpSession(SseResponseStreamTransport transport, ClaimsPrincipal user } public SseResponseStreamTransport Transport { get; } - public (string ClaimType, string ClaimValue, string Issuer)? UserIdClaim { get; } + public (string Type, string Value, string Issuer)? UserIdClaim { get; } public bool HasSameUserId(ClaimsPrincipal user) - => UserIdClaim?.ClaimValue == GetUserIdClaim(user)?.ClaimValue; + => UserIdClaim == GetUserIdClaim(user); // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than // verifying antiforgery tokens from posts. - private static (string ClaimType, string ClaimValue, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) + private static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) { if (user?.Identity?.IsAuthenticated != true) { From d9e737cca8181a696966c20d8d76ff513a68df08 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 13 Apr 2025 15:00:12 -0700 Subject: [PATCH 12/12] Remove MaxReconnectAttempts and ReconnectDelay from SseClientTransportOptions - Add proper AdditionalHeaders support --- .../Transport/SseClientSessionTransport.cs | 117 +++++++++--------- .../Transport/SseClientTransportOptions.cs | 31 ----- .../SseIntegrationTests.cs | 93 ++++++++++++-- .../Transport/SseClientTransportTests.cs | 13 +- 4 files changed, 142 insertions(+), 112 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 3d5cdf2bb..168e25818 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -101,11 +101,12 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } - var response = await _httpClient.PostAsync( - _messageEndpoint, - content, - cancellationToken - ).ConfigureAwait(false); + var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) + { + Content = content, + }; + CopyAdditionalHeaders(httpRequestMessage.Headers); + var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); @@ -182,72 +183,52 @@ public override async ValueTask DisposeAsync() private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { - int reconnectAttempts = 0; - - while (!cancellationToken.IsCancellationRequested && !IsConnected) + try { - try - { - using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); - request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + CopyAdditionalHeaders(request.Headers); - if (_options.AdditionalHeaders != null) - { - foreach (var header in _options.AdditionalHeaders) - { - request.Headers.Add(header.Key, header.Value); - } - } - - using var response = await _httpClient.SendAsync( - request, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); + using var response = await _httpClient.SendAsync( + request, + HttpCompletionOption.ResponseHeadersRead, + cancellationToken + ).ConfigureAwait(false); - response.EnsureSuccessStatusCode(); + response.EnsureSuccessStatusCode(); - using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) - { - switch (sseEvent.EventType) - { - case "endpoint": - HandleEndpointEvent(sseEvent.Data); - break; - - case "message": - await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false); - break; - } - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - _logger.TransportReadMessagesCancelled(_endpointName); - // Normal shutdown - } - catch (IOException) when (cancellationToken.IsCancellationRequested) - { - _logger.TransportReadMessagesCancelled(_endpointName); - // Normal shutdown - } - catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) { - _logger.TransportConnectionError(_endpointName, ex); - - reconnectAttempts++; - if (reconnectAttempts >= _options.MaxReconnectAttempts) + switch (sseEvent.EventType) { - throw new McpTransportException("Exceeded reconnect limit", ex); - } + case "endpoint": + HandleEndpointEvent(sseEvent.Data); + break; - await Task.Delay(_options.ReconnectDelay, cancellationToken).ConfigureAwait(false); + case "message": + await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false); + break; + } } } - - SetConnected(false); + catch when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + _connectionEstablished.TrySetCanceled(cancellationToken); + _logger.TransportReadMessagesCancelled(_endpointName); + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + _connectionEstablished.TrySetException(ex); + _logger.TransportConnectionError(_endpointName, ex); + throw; + } + finally + { + SetConnected(false); + } } private async Task ProcessSseMessage(string data, CancellationToken cancellationToken) @@ -306,4 +287,18 @@ private void HandleEndpointEvent(string data) throw new McpTransportException("Failed to parse endpoint event", ex); } } + + private void CopyAdditionalHeaders(HttpRequestHeaders headers) + { + if (_options.AdditionalHeaders is not null) + { + foreach (var header in _options.AdditionalHeaders) + { + if (!headers.TryAddWithoutValidation(header.Key, header.Value)) + { + throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); + } + } + } + } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index aa3942a8f..0a36a15f9 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -48,37 +48,6 @@ public required Uri Endpoint /// public TimeSpan ConnectionTimeout { get; init; } = TimeSpan.FromSeconds(30); - /// - /// Gets or sets the maximum number of reconnection attempts for the SSE connection before giving up. - /// - /// - /// - /// This property controls how many times the client will attempt to reconnect to the SSE server - /// after a connection failure occurs. If all reconnection attempts fail, a - /// with the message "Exceeded reconnect limit" will be thrown. - /// - /// - /// Between each reconnection attempt, the client will wait for the duration specified by . - /// - /// - public int MaxReconnectAttempts { get; init; } = 3; - - /// - /// Gets or sets the delay to employ between reconnection attempts when the SSE connection fails. - /// - /// - /// - /// When a connection to the SSE server is lost or fails, the client will wait for this duration - /// before attempting to reconnect. This helps prevent excessive reconnection attempts in quick succession - /// which could overload the server or network. - /// - /// - /// The reconnection process continues until either a successful connection is established or - /// the maximum number of reconnection attempts () is reached. - /// - /// - public TimeSpan ReconnectDelay { get; init; } = TimeSpan.FromSeconds(5); - /// /// Gets custom HTTP headers to include in requests to the SSE server. /// diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 0baaf5163..8cb5cef1f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -23,12 +23,11 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr Name = "In-memory Test Server", }; - private Task ConnectMcpClient(HttpClient httpClient, McpClientOptions? clientOptions = null) + private Task ConnectMcpClient(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) => McpClientFactory.CreateAsync( - new SseClientTransport(DefaultTransportOptions, httpClient, LoggerFactory), - clientOptions, - LoggerFactory, - TestContext.Current.CancellationToken); + new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() @@ -38,7 +37,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(HttpClient); + await using var mcpClient = await ConnectMcpClient(); // Send a test message through POST endpoint await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); @@ -53,7 +52,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU MapAbsoluteEndpointUriMcp(app); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(HttpClient); + await using var mcpClient = await ConnectMcpClient(); // Send a test message through POST endpoint await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); @@ -85,7 +84,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(HttpClient); + await using var mcpClient = await ConnectMcpClient(); mcpClient.RegisterNotificationHandler("test/notification", (args, ca) => { @@ -109,14 +108,14 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() Builder.Services.AddMcpServer(options => { - Interlocked.Increment(ref firstOptionsCallbackCallCount); + firstOptionsCallbackCallCount++; }) .WithHttpTransport() .WithTools(); Builder.Services.AddMcpServer(options => { - Interlocked.Increment(ref secondOptionsCallbackCallCount); + secondOptionsCallbackCallCount++; }) .WithTools(); @@ -125,7 +124,7 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(HttpClient); + await using var mcpClient = await ConnectMcpClient(); // Options can be lazily initialized, but they must be instantiated by the time an MCP client can finish connecting. // Callbacks can be called multiple times if configureOptionsAsync is configured, because that uses the IOptionsFactory, @@ -151,6 +150,78 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() Assert.Equal("hello from client!", textContent.Text); } + [Fact] + public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() + { + Builder.Services.AddMcpServer() + .WithHttpTransport(); + + await using var app = Builder.Build(); + + bool wasGetRequest = false; + bool wasPostRequest = false; + + app.Use(next => + { + return async context => + { + Assert.Equal("Bearer testToken", context.Request.Headers["Authorize"]); + if (context.Request.Method == HttpMethods.Get) + { + wasGetRequest = true; + } + else if (context.Request.Method == HttpMethods.Post) + { + wasPostRequest = true; + } + await next(context); + }; + }); + + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + var sseOptions = new SseClientTransportOptions() + { + Endpoint = new Uri("http://localhost/sse"), + Name = "In-memory Test Server", + AdditionalHeaders = new() + { + ["Authorize"] = "Bearer testToken" + }, + }; + + await using var mcpClient = await ConnectMcpClient(transportOptions: sseOptions); + + Assert.True(wasGetRequest); + Assert.True(wasPostRequest); + } + + [Fact] + public async Task EmptyAdditionalHeadersKey_Throws_InvalidOpearionException() + { + Builder.Services.AddMcpServer() + .WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + var sseOptions = new SseClientTransportOptions() + { + Endpoint = new Uri("http://localhost/sse"), + Name = "In-memory Test Server", + AdditionalHeaders = new() + { + [""] = "" + }, + }; + + var ex = await Assert.ThrowsAsync(() => ConnectMcpClient(transportOptions: sseOptions)); + Assert.Equal("Failed to add header '' with value '' from AdditionalHeaders.", ex.Message); + } + private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints) { var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index db753c6b1..6cfdd06cb 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -17,8 +17,6 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper) { Endpoint = new Uri("http://localhost:8080"), ConnectionTimeout = TimeSpan.FromSeconds(2), - MaxReconnectAttempts = 3, - ReconnectDelay = TimeSpan.FromMilliseconds(50), Name = "Test Server", AdditionalHeaders = new Dictionary { @@ -76,15 +74,12 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() mockHttpHandler.RequestHandler = (request) => { retries++; - throw new InvalidOperationException("Test exception"); + throw new Exception("Test exception"); }; - var action = async () => await transport.ConnectAsync(); - - var exception = await Assert.ThrowsAsync(action); - Assert.Equal("Exceeded reconnect limit", exception.Message); - - Assert.Equal(_transportOptions.MaxReconnectAttempts, retries); + var exception = await Assert.ThrowsAsync(() => transport.ConnectAsync(TestContext.Current.CancellationToken)); + Assert.Equal("Test exception", exception.Message); + Assert.Equal(1, retries); } [Fact]