Skip to content

Commit fd60a1c

Browse files
committed
Trim implementation
1 parent 400f191 commit fd60a1c

File tree

6 files changed

+182
-135
lines changed

6 files changed

+182
-135
lines changed

src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder
2525

2626
builder.Services.TryAddSingleton<StreamableHttpHandler>();
2727
builder.Services.TryAddSingleton<SseHandler>();
28-
builder.Services.TryAddSingleton<McpAuthorizationFilterFactory>();
2928
builder.Services.AddHostedService<IdleTrackingBackgroundService>();
3029

3130
if (configureOptions is not null)

src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs

Lines changed: 0 additions & 53 deletions
This file was deleted.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Microsoft.AspNetCore.Builder;
2+
using Microsoft.AspNetCore.Http;
3+
using Microsoft.Extensions.DependencyInjection;
4+
using Microsoft.Extensions.Logging;
5+
using ModelContextProtocol.Protocol.Auth;
6+
7+
namespace ModelContextProtocol.AspNetCore;
8+
9+
/// <summary>
10+
/// Provides extension methods for adding MCP authorization to endpoints.
11+
/// </summary>
12+
public static class McpEndpointAuthorizationExtensions
13+
{
14+
/// <summary>
15+
/// Adds MCP authorization filter to an endpoint.
16+
/// </summary>
17+
/// <param name="builder">The endpoint convention builder.</param>
18+
/// <param name="authProvider">The authorization provider.</param>
19+
/// <param name="serviceProvider">The service provider.</param>
20+
/// <returns>The builder for chaining.</returns>
21+
public static IEndpointConventionBuilder AddMcpAuthorization(
22+
this IEndpointConventionBuilder builder,
23+
IServerAuthorizationProvider authProvider,
24+
IServiceProvider serviceProvider)
25+
{
26+
if (authProvider == null)
27+
{
28+
return builder; // No authorization needed
29+
}
30+
31+
var logger = serviceProvider.GetRequiredService<ILogger<McpEndpointAuthorizationFilter>>();
32+
var filter = new McpEndpointAuthorizationFilter(logger, authProvider);
33+
34+
return builder.AddEndpointFilter(filter);
35+
}
36+
37+
/// <summary>
38+
/// Adds MCP authorization filter to multiple endpoints.
39+
/// </summary>
40+
/// <param name="endpoints">The collection of endpoint convention builders.</param>
41+
/// <param name="authProvider">The authorization provider.</param>
42+
/// <param name="serviceProvider">The service provider.</param>
43+
/// <returns>The original collection for chaining.</returns>
44+
public static IEnumerable<IEndpointConventionBuilder> AddMcpAuthorization(
45+
this IEnumerable<IEndpointConventionBuilder> endpoints,
46+
IServerAuthorizationProvider authProvider,
47+
IServiceProvider serviceProvider)
48+
{
49+
if (authProvider == null)
50+
{
51+
return endpoints; // No authorization needed
52+
}
53+
54+
var logger = serviceProvider.GetRequiredService<ILogger<McpEndpointAuthorizationFilter>>();
55+
var filter = new McpEndpointAuthorizationFilter(logger, authProvider);
56+
57+
foreach (var endpoint in endpoints)
58+
{
59+
endpoint.AddEndpointFilter(filter);
60+
}
61+
62+
return endpoints;
63+
}
64+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// filepath: c:\Users\ddelimarsky\source\csharp-sdk\src\ModelContextProtocol.AspNetCore\McpEndpointAuthorizationFilter.cs
2+
using Microsoft.AspNetCore.Http;
3+
using Microsoft.Extensions.Logging;
4+
using ModelContextProtocol.Protocol.Auth;
5+
6+
namespace ModelContextProtocol.AspNetCore;
7+
8+
/// <summary>
9+
/// An endpoint filter that handles authorization for MCP endpoints using the standard ASP.NET Core endpoint filter pattern.
10+
/// </summary>
11+
internal class McpEndpointAuthorizationFilter : IEndpointFilter
12+
{
13+
private readonly ILogger _logger;
14+
private readonly IServerAuthorizationProvider _authProvider;
15+
16+
/// <summary>
17+
/// Initializes a new instance of the <see cref="McpEndpointAuthorizationFilter"/> class.
18+
/// </summary>
19+
/// <param name="logger">The logger.</param>
20+
/// <param name="authProvider">The authorization provider.</param>
21+
public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvider authProvider)
22+
{
23+
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
24+
_authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider));
25+
}
26+
27+
/// <inheritdoc/>
28+
public async ValueTask<object?> InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next)
29+
{
30+
var httpContext = context.HttpContext;
31+
32+
// Check if the Authorization header is present
33+
if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader))
34+
{
35+
// No Authorization header present, return 401 Unauthorized
36+
var prm = _authProvider.GetProtectedResourceMetadata();
37+
var prmUrl = GetPrmUrl(httpContext, prm.Resource);
38+
39+
_logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header");
40+
httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
41+
httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\"");
42+
return Results.Empty;
43+
}
44+
45+
// Validate the token
46+
string authHeaderValue = authHeader.ToString();
47+
bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue);
48+
if (!isValid)
49+
{
50+
// Invalid token, return 401 Unauthorized
51+
var prm = _authProvider.GetProtectedResourceMetadata();
52+
var prmUrl = GetPrmUrl(httpContext, prm.Resource);
53+
54+
_logger.LogDebug("Invalid authorization token, returning 401 Unauthorized");
55+
httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
56+
httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\"");
57+
return Results.Empty;
58+
}
59+
60+
// Token is valid, proceed to the next filter
61+
return await next(context);
62+
}
63+
64+
/// <summary>
65+
/// Builds the URL for the protected resource metadata endpoint.
66+
/// </summary>
67+
/// <param name="context">The HTTP context.</param>
68+
/// <param name="resourceUri">The resource URI from the protected resource metadata.</param>
69+
/// <returns>The full URL to the protected resource metadata endpoint.</returns>
70+
private static string GetPrmUrl(HttpContext context, string resourceUri)
71+
{
72+
// Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL
73+
if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _))
74+
{
75+
return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource";
76+
}
77+
78+
// Build the URL from the current request
79+
var request = context.Request;
80+
var scheme = request.Scheme;
81+
var host = request.Host.Value;
82+
return $"{scheme}://{host}/.well-known/oauth-protected-resource";
83+
}
84+
}

src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,23 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
6666
streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync);
6767

6868
// Map legacy HTTP with SSE endpoints.
69-
var sseHandler = endpoints.ServiceProvider.GetRequiredService<SseHandler>();
70-
var sseGroup = mcpGroup.MapGroup("")
69+
var sseHandler = endpoints.ServiceProvider.GetRequiredService<SseHandler>(); var sseGroup = mcpGroup.MapGroup("")
7170
.WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}");
7271

73-
// Apply authorization filter to SSE endpoints if authorization is configured
74-
if (authProvider != null)
75-
{
76-
// Create the filter factory
77-
var filterFactory = endpoints.ServiceProvider.GetRequiredService<McpAuthorizationFilterFactory>();
78-
79-
// Apply filter to SSE and message endpoints
80-
sseGroup.AddEndpointFilterFactory(filterFactory.Create);
81-
}
82-
83-
sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync)
72+
// Configure SSE endpoints
73+
var sseEndpoint = sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync)
8474
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
85-
sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync)
75+
76+
var messageEndpoint = sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync)
8677
.WithMetadata(new AcceptsMetadata(["application/json"]))
8778
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted));
8879

80+
// Apply authorization filter directly to SSE endpoints if authorization is configured
81+
if (authProvider != null)
82+
{
83+
// Apply authorization to both endpoints using the extension method
84+
new[] { sseEndpoint, messageEndpoint }.AddMcpAuthorization(authProvider, endpoints.ServiceProvider);
85+
}
8986
return mcpGroup;
9087
}
91-
}
88+
}

src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs

Lines changed: 22 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.Extensions.Logging.Abstractions;
3-
using ModelContextProtocol.Protocol.Auth;
43
using ModelContextProtocol.Protocol.Messages;
54
using ModelContextProtocol.Utils;
65
using ModelContextProtocol.Utils.Json;
76
using System.Diagnostics;
8-
using System.Net;
97
using System.Net.Http.Headers;
108
using System.Net.ServerSentEvents;
119
using System.Text;
@@ -26,7 +24,6 @@ internal sealed partial class SseClientSessionTransport : TransportBase
2624
private Task? _receiveTask;
2725
private readonly ILogger _logger;
2826
private readonly TaskCompletionSource<bool> _connectionEstablished;
29-
private readonly IAuthorizationHandler _authorizationHandler;
3027

3128
/// <summary>
3229
/// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server.
@@ -48,18 +45,6 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Htt
4845
_connectionCts = new CancellationTokenSource();
4946
_logger = (ILogger?)loggerFactory?.CreateLogger<SseClientTransport>() ?? NullLogger.Instance;
5047
_connectionEstablished = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
51-
52-
// Initialize the authorization handler
53-
if (transportOptions.AuthorizationOptions?.AuthorizationHandler != null)
54-
{
55-
// Use explicitly provided handler
56-
_authorizationHandler = transportOptions.AuthorizationOptions.AuthorizationHandler;
57-
}
58-
else
59-
{
60-
// Create default handler with auth options
61-
_authorizationHandler = new DefaultAuthorizationHandler(loggerFactory, transportOptions.AuthorizationOptions);
62-
}
6348
}
6449

6550
/// <inheritdoc/>
@@ -89,48 +74,18 @@ public override async Task SendMessageAsync(
8974
if (_messageEndpoint == null)
9075
throw new InvalidOperationException("Transport not connected");
9176

77+
using var content = new StringContent(
78+
JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage),
79+
Encoding.UTF8,
80+
"application/json"
81+
);
82+
9283
string messageId = "(no id)";
9384

9485
if (message is JsonRpcMessageWithId messageWithId)
9586
{
9687
messageId = messageWithId.Id.ToString();
9788
}
98-
99-
// Send the request, handling potential auth challenges
100-
HttpResponseMessage? response = null;
101-
bool authRetry = false;
102-
103-
do
104-
{
105-
authRetry = false;
106-
107-
// Create a new request for each attempt
108-
using var currentRequest = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint);
109-
currentRequest.Content = new StringContent(
110-
JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage),
111-
Encoding.UTF8,
112-
"application/json"
113-
);
114-
115-
// Add authorization headers if needed - the handler will only add headers if auth is required
116-
await _authorizationHandler.AuthenticateRequestAsync(currentRequest).ConfigureAwait(false);
117-
118-
// Copy additional headers
119-
CopyAdditionalHeaders(currentRequest.Headers);
120-
121-
// Dispose previous response before making a new request
122-
response?.Dispose();
123-
124-
response = await _httpClient.SendAsync(currentRequest, cancellationToken).ConfigureAwait(false);
125-
126-
// Handle 401 Unauthorized response - this will only execute if the server requires auth
127-
if (response.StatusCode == HttpStatusCode.Unauthorized)
128-
{
129-
// Try to handle the unauthorized response
130-
authRetry = await _authorizationHandler.HandleUnauthorizedResponseAsync(
131-
response, _messageEndpoint).ConfigureAwait(false);
132-
}
133-
} while (authRetry);
13489

13590
using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
13691
{
@@ -139,25 +94,26 @@ public override async Task SendMessageAsync(
13994
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders);
14095
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
14196

142-
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
97+
response.EnsureSuccessStatusCode();
14398

144-
// Check if the message was an initialize request
145-
if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize)
146-
{
147-
// If the response is not a JSON-RPC response, it is an SSE message
148-
if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
149-
{
150-
LogAcceptedPost(Name, messageId);
151-
// The response will arrive as an SSE message
152-
}
153-
else
154-
{
155-
JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ??
156-
throw new InvalidOperationException("Failed to initialize client");
99+
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
157100

158101
if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
159102
{
160-
response.Dispose();
103+
LogAcceptedPost(Name, messageId);
104+
}
105+
else
106+
{
107+
if (_logger.IsEnabled(LogLevel.Trace))
108+
{
109+
LogRejectedPostSensitive(Name, messageId, responseContent);
110+
}
111+
else
112+
{
113+
LogRejectedPost(Name, messageId);
114+
}
115+
116+
throw new InvalidOperationException("Failed to send message");
161117
}
162118
}
163119

0 commit comments

Comments
 (0)