Skip to content

Commit 9bf4ea3

Browse files
committed
Amend middleware logic
1 parent 3fd7681 commit 9bf4ea3

File tree

6 files changed

+220
-38
lines changed

6 files changed

+220
-38
lines changed

src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ public async Task InvokeAsync(
4747
return;
4848
}
4949

50-
// Handle the PRM document endpoint
51-
if (context.Request.Path.StartsWithSegments("/.well-known/oauth-protected-resource"))
50+
// Handle the PRM document endpoint if not handled by the endpoint
51+
if (context.Request.Path.StartsWithSegments("/.well-known/oauth-protected-resource") &&
52+
context.GetEndpoint() == null)
5253
{
5354
_logger.LogDebug("Serving Protected Resource Metadata document");
5455
context.Response.ContentType = "application/json";
@@ -59,40 +60,8 @@ await JsonSerializer.SerializeAsync(
5960
return;
6061
}
6162

62-
// Serve SSE and message endpoints with authorization
63-
if (context.Request.Path.StartsWithSegments("/sse") ||
64-
(context.Request.Path.Value?.EndsWith("/message") == true))
65-
{
66-
// Check if the Authorization header is present
67-
if (!context.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader))
68-
{
69-
// No Authorization header present, return 401 Unauthorized
70-
var prm = authProvider.GetProtectedResourceMetadata();
71-
var prmUrl = GetPrmUrl(context, prm.Resource);
72-
73-
_logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header");
74-
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
75-
context.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\"");
76-
return;
77-
}
78-
79-
// Validate the token - ensuring authHeader is a non-null string
80-
string authHeaderValue = authHeader.ToString();
81-
bool isValid = await authProvider.ValidateTokenAsync(authHeaderValue);
82-
if (!isValid)
83-
{
84-
// Invalid token, return 401 Unauthorized
85-
var prm = authProvider.GetProtectedResourceMetadata();
86-
var prmUrl = GetPrmUrl(context, prm.Resource);
87-
88-
_logger.LogDebug("Invalid authorization token, returning 401 Unauthorized");
89-
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
90-
context.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\"");
91-
return;
92-
}
93-
}
94-
95-
// Token is valid or endpoint doesn't require authentication, proceed to the next middleware
63+
// Proceed to the next middleware - authorization for SSE and message endpoints
64+
// is now handled by endpoint filters
9665
await _next(context);
9766
}
9867

src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ public static class HttpMcpServerBuilderExtensions
1818
/// <param name="configureOptions">Configures options for the Streamable HTTP transport. This allows configuring per-session
1919
/// <see cref="McpServerOptions"/> and running logic before and after a session.</param>
2020
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
21-
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
21+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
2222
public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action<HttpServerTransportOptions>? configureOptions = null)
2323
{
2424
ArgumentNullException.ThrowIfNull(builder);
2525
builder.Services.TryAddSingleton<StreamableHttpHandler>();
2626
builder.Services.TryAddSingleton<SseHandler>();
27+
builder.Services.TryAddSingleton<McpAuthorizationFilterFactory>();
2728
builder.Services.AddHostedService<IdleTrackingBackgroundService>();
2829

2930
if (configureOptions is not null)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using Microsoft.AspNetCore.Http;
2+
using Microsoft.Extensions.Logging;
3+
using ModelContextProtocol.Protocol.Auth;
4+
5+
namespace ModelContextProtocol.AspNetCore;
6+
7+
/// <summary>
8+
/// An endpoint filter that handles authorization for MCP endpoints.
9+
/// </summary>
10+
internal class McpAuthorizationFilter : IEndpointFilter
11+
{
12+
private readonly ILogger<McpAuthorizationFilter> _logger;
13+
private readonly IServerAuthorizationProvider _authProvider;
14+
15+
/// <summary>
16+
/// Initializes a new instance of the <see cref="McpAuthorizationFilter"/> class.
17+
/// </summary>
18+
/// <param name="logger">The logger.</param>
19+
/// <param name="authProvider">The authorization provider.</param>
20+
public McpAuthorizationFilter(
21+
ILogger<McpAuthorizationFilter> logger,
22+
IServerAuthorizationProvider authProvider)
23+
{
24+
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
25+
_authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider));
26+
}
27+
28+
/// <inheritdoc/>
29+
public async ValueTask<object?> InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next)
30+
{
31+
var httpContext = context.HttpContext;
32+
33+
// Check if the Authorization header is present
34+
if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader))
35+
{
36+
// No Authorization header present, return 401 Unauthorized
37+
var prm = _authProvider.GetProtectedResourceMetadata();
38+
var prmUrl = GetPrmUrl(httpContext, prm.Resource);
39+
40+
_logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header");
41+
httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
42+
httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\"");
43+
return Results.Empty;
44+
}
45+
46+
// Validate the token - ensuring authHeader is a non-null string
47+
string authHeaderValue = authHeader.ToString();
48+
bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue);
49+
if (!isValid)
50+
{
51+
// Invalid token, return 401 Unauthorized
52+
var prm = _authProvider.GetProtectedResourceMetadata();
53+
var prmUrl = GetPrmUrl(httpContext, prm.Resource);
54+
55+
_logger.LogDebug("Invalid authorization token, returning 401 Unauthorized");
56+
httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
57+
httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\"");
58+
return Results.Empty;
59+
}
60+
61+
// Token is valid, proceed to the next filter
62+
return await next(context);
63+
}
64+
65+
private static string GetPrmUrl(HttpContext context, string resourceUri)
66+
{
67+
// Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL
68+
if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _))
69+
{
70+
return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource";
71+
}
72+
73+
// Build the URL from the current request
74+
var request = context.Request;
75+
var scheme = request.Scheme;
76+
var host = request.Host.Value;
77+
return $"{scheme}://{host}/.well-known/oauth-protected-resource";
78+
}
79+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using Microsoft.AspNetCore.Http;
2+
using Microsoft.Extensions.DependencyInjection;
3+
using Microsoft.Extensions.Logging;
4+
using ModelContextProtocol.Protocol.Auth;
5+
6+
namespace ModelContextProtocol.AspNetCore;
7+
8+
/// <summary>
9+
/// Factory for creating <see cref="McpAuthorizationFilter"/> instances.
10+
/// </summary>
11+
internal class McpAuthorizationFilterFactory
12+
{
13+
private readonly IServiceProvider _serviceProvider;
14+
15+
/// <summary>
16+
/// Initializes a new instance of the <see cref="McpAuthorizationFilterFactory"/> class.
17+
/// </summary>
18+
/// <param name="serviceProvider">The service provider.</param>
19+
public McpAuthorizationFilterFactory(IServiceProvider serviceProvider)
20+
{
21+
_serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
22+
}
23+
24+
/// <summary>
25+
/// Creates an endpoint filter delegate for authorization.
26+
/// </summary>
27+
/// <param name="context">The endpoint filter factory context.</param>
28+
/// <param name="next">The next filter delegate in the pipeline.</param>
29+
/// <returns>The filter delegate.</returns>
30+
public EndpointFilterDelegate Create(EndpointFilterFactoryContext context, EndpointFilterDelegate next)
31+
{
32+
// This factory creates a filter that checks if the current endpoint is an SSE or message endpoint
33+
// and applies authorization only to those endpoints
34+
return async invocationContext =>
35+
{
36+
var httpContext = invocationContext.HttpContext;
37+
var path = httpContext.Request.Path.Value?.TrimEnd('/');
38+
39+
// Only apply authorization to /sse and /message endpoints
40+
if (path != null && (path.EndsWith("/sse") || path.EndsWith("/message")))
41+
{
42+
var authProvider = _serviceProvider.GetRequiredService<IServerAuthorizationProvider>();
43+
var logger = _serviceProvider.GetRequiredService<ILogger<McpAuthorizationFilter>>();
44+
45+
var filter = new McpAuthorizationFilter(logger, authProvider);
46+
return await filter.InvokeAsync(invocationContext, next);
47+
}
48+
49+
// For all other endpoints, just invoke the next filter
50+
return await next(invocationContext);
51+
};
52+
}
53+
}

src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
using Microsoft.AspNetCore.Http.Metadata;
33
using Microsoft.AspNetCore.Routing;
44
using Microsoft.Extensions.DependencyInjection;
5+
using Microsoft.Extensions.Logging;
56
using ModelContextProtocol.AspNetCore;
7+
using ModelContextProtocol.Protocol.Auth;
68
using ModelContextProtocol.Protocol.Messages;
79
using System.Diagnostics.CodeAnalysis;
810

@@ -20,12 +22,36 @@ public static class McpEndpointRouteBuilderExtensions
2022
/// </summary>
2123
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
2224
/// <param name="pattern">The route pattern prefix to map to.</param>
23-
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
25+
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
2426
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "")
2527
{
2628
var streamableHttpHandler = endpoints.ServiceProvider.GetService<StreamableHttpHandler>() ??
2729
throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code.");
2830

31+
// Map the protected resource metadata endpoint if authorization is configured
32+
var authProvider = endpoints.ServiceProvider.GetService<IServerAuthorizationProvider>();
33+
if (authProvider != null)
34+
{
35+
// Create and register the ProtectedResourceMetadataHandler if it's not already registered
36+
ProtectedResourceMetadataHandler? prmHandler = null;
37+
try
38+
{
39+
prmHandler = endpoints.ServiceProvider.GetService<ProtectedResourceMetadataHandler>();
40+
}
41+
catch
42+
{
43+
// Ignore - we'll create it below
44+
}
45+
46+
if (prmHandler == null)
47+
{
48+
var logger = endpoints.ServiceProvider.GetRequiredService<ILogger<ProtectedResourceMetadataHandler>>();
49+
prmHandler = new ProtectedResourceMetadataHandler(logger, authProvider);
50+
}
51+
52+
endpoints.MapGet("/.well-known/oauth-protected-resource", prmHandler.HandleAsync);
53+
}
54+
2955
var mcpGroup = endpoints.MapGroup(pattern);
3056
var streamableHttpGroup = mcpGroup.MapGroup("")
3157
.WithDisplayName(b => $"MCP Streamable HTTP | {b.DisplayName}")
@@ -44,6 +70,16 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
4470
var sseGroup = mcpGroup.MapGroup("")
4571
.WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}");
4672

73+
// Apply authorization filter to SSE endpoints if authorization is configured
74+
if (authProvider != null)
75+
{
76+
// Create the filter factory
77+
var filterFactory = endpoints.ServiceProvider.GetRequiredService<McpAuthorizationFilterFactory>();
78+
79+
// Apply filter to SSE and message endpoints
80+
sseGroup.AddEndpointFilterFactory(filterFactory.Create);
81+
}
82+
4783
sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync)
4884
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
4985
sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using Microsoft.AspNetCore.Http;
2+
using Microsoft.Extensions.Logging;
3+
using ModelContextProtocol.Protocol.Auth;
4+
using ModelContextProtocol.Utils.Json;
5+
using System.Text.Json;
6+
7+
namespace ModelContextProtocol.AspNetCore;
8+
9+
/// <summary>
10+
/// Handler for the Protected Resource Metadata document endpoint.
11+
/// </summary>
12+
internal class ProtectedResourceMetadataHandler
13+
{
14+
private readonly ILogger<ProtectedResourceMetadataHandler> _logger;
15+
private readonly IServerAuthorizationProvider _authProvider;
16+
17+
/// <summary>
18+
/// Initializes a new instance of the <see cref="ProtectedResourceMetadataHandler"/> class.
19+
/// </summary>
20+
/// <param name="logger">The logger.</param>
21+
/// <param name="authProvider">The authorization provider.</param>
22+
public ProtectedResourceMetadataHandler(
23+
ILogger<ProtectedResourceMetadataHandler> logger,
24+
IServerAuthorizationProvider authProvider)
25+
{
26+
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
27+
_authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider));
28+
}
29+
30+
/// <summary>
31+
/// Handles the request for the Protected Resource Metadata document.
32+
/// </summary>
33+
/// <param name="context">The HTTP context.</param>
34+
/// <returns>A task that represents the asynchronous operation.</returns>
35+
public async Task HandleAsync(HttpContext context)
36+
{
37+
_logger.LogDebug("Serving Protected Resource Metadata document");
38+
context.Response.ContentType = "application/json";
39+
await JsonSerializer.SerializeAsync(
40+
context.Response.Body,
41+
_authProvider.GetProtectedResourceMetadata(),
42+
McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata)));
43+
}
44+
}

0 commit comments

Comments
 (0)