Skip to content

Commit 1de26c7

Browse files
committed
Manually wrap HttpClient rather than use DelegatingHandler for client auth
1 parent cb52e85 commit 1de26c7

File tree

8 files changed

+102
-138
lines changed

8 files changed

+102
-138
lines changed

samples/ProtectedMCPClient/Program.cs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Net;
66
using System.Text;
77
using System.Web;
8-
using static System.Runtime.InteropServices.JavaScript.JSType;
98

109
Console.WriteLine("Protected MCP Weather Server");
1110
Console.WriteLine();
@@ -39,17 +38,12 @@
3938

4039
try
4140
{
42-
var transportOptions = new SseClientTransportOptions
41+
var transport = new SseClientTransport(new()
4342
{
4443
Endpoint = new Uri(serverUrl),
45-
Name = "Secure Weather Client"
46-
};
47-
48-
// Create a transport with authentication support using the correct constructor parameters
49-
var transport = new SseClientTransport(
50-
transportOptions,
51-
tokenProvider
52-
);
44+
Name = "Secure Weather Client",
45+
CredentialProvider = tokenProvider,
46+
});
5347

5448
var client = await McpClientFactory.CreateAsync(transport);
5549

src/ModelContextProtocol.Core/Authentication/AuthorizationDelegatingHandler.cs renamed to src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,35 @@
1+
using ModelContextProtocol.Client;
2+
using ModelContextProtocol.Protocol;
13
using System.Net.Http.Headers;
24

35
namespace ModelContextProtocol.Authentication;
46

57
/// <summary>
68
/// A delegating handler that adds authentication tokens to requests and handles 401 responses.
79
/// </summary>
8-
public class AuthorizationDelegatingHandler : DelegatingHandler
10+
internal sealed class AuthenticatingMcpHttpClient(HttpClient httpClient, IMcpCredentialProvider credentialProvider) : McpHttpClient(httpClient)
911
{
10-
private readonly IMcpCredentialProvider _credentialProvider;
11-
private string _currentScheme;
12-
private static readonly char[] SchemeSplitDelimiters = { ' ', ',' };
12+
private static readonly char[] SchemeSplitDelimiters = [' ', ','];
1313

14-
/// <summary>
15-
/// Initializes a new instance of the <see cref="AuthorizationDelegatingHandler"/> class.
16-
/// </summary>
17-
/// <param name="credentialProvider">The provider that supplies authentication tokens.</param>
18-
public AuthorizationDelegatingHandler(IMcpCredentialProvider credentialProvider)
19-
{
20-
Throw.IfNull(credentialProvider);
21-
22-
_credentialProvider = credentialProvider;
23-
24-
// Select first supported scheme as the default
25-
_currentScheme = _credentialProvider.SupportedSchemes.FirstOrDefault() ??
26-
throw new ArgumentException("Authorization provider must support at least one authentication scheme.", nameof(credentialProvider));
27-
}
14+
// Select first supported scheme as the default
15+
private string _currentScheme = credentialProvider.SupportedSchemes.FirstOrDefault() ??
16+
throw new ArgumentException("Authorization provider must support at least one authentication scheme.", nameof(credentialProvider));
2817

2918
/// <summary>
3019
/// Sends an HTTP request with authentication handling.
3120
/// </summary>
32-
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
21+
internal override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken)
3322
{
3423
if (request.Headers.Authorization == null)
3524
{
3625
await AddAuthorizationHeaderAsync(request, _currentScheme, cancellationToken).ConfigureAwait(false);
3726
}
3827

39-
var response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
28+
var response = await base.SendAsync(request, message, cancellationToken).ConfigureAwait(false);
4029

4130
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized)
4231
{
43-
return await HandleUnauthorizedResponseAsync(request, response, cancellationToken).ConfigureAwait(false);
32+
return await HandleUnauthorizedResponseAsync(request, message, response, cancellationToken).ConfigureAwait(false);
4433
}
4534

4635
return response;
@@ -51,6 +40,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
5140
/// </summary>
5241
private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
5342
HttpRequestMessage originalRequest,
43+
JsonRpcMessage? originalJsonRpcMessage,
5444
HttpResponseMessage response,
5545
CancellationToken cancellationToken)
5646
{
@@ -64,14 +54,14 @@ private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
6454
string schemeUsed = originalRequest.Headers.Authorization?.Scheme ?? _currentScheme ?? string.Empty;
6555
if (!string.IsNullOrEmpty(schemeUsed) &&
6656
serverSchemes.Contains(schemeUsed) &&
67-
_credentialProvider.SupportedSchemes.Contains(schemeUsed))
57+
credentialProvider.SupportedSchemes.Contains(schemeUsed))
6858
{
6959
bestSchemeMatch = schemeUsed;
7060
}
7161
else
7262
{
7363
// Find the first server scheme that's in our supported set
74-
bestSchemeMatch = serverSchemes.Intersect(_credentialProvider.SupportedSchemes, StringComparer.OrdinalIgnoreCase).FirstOrDefault();
64+
bestSchemeMatch = serverSchemes.Intersect(credentialProvider.SupportedSchemes, StringComparer.OrdinalIgnoreCase).FirstOrDefault();
7565

7666
// If no match was found, either throw an exception or use default
7767
if (bestSchemeMatch is null)
@@ -81,20 +71,20 @@ private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
8171
throw new IOException(
8272
$"The server does not support any of the provided authentication schemes." +
8373
$"Server supports: [{string.Join(", ", serverSchemes)}], " +
84-
$"Provider supports: [{string.Join(", ", _credentialProvider.SupportedSchemes)}].");
74+
$"Provider supports: [{string.Join(", ", credentialProvider.SupportedSchemes)}].");
8575
}
8676

8777
// If the server didn't specify any schemes, use the provider's default
88-
bestSchemeMatch = _credentialProvider.SupportedSchemes.FirstOrDefault();
78+
bestSchemeMatch = credentialProvider.SupportedSchemes.FirstOrDefault();
8979
}
9080
}
91-
// If we have a scheme to try, use it
81+
9282
if (bestSchemeMatch != null)
9383
{
9484
try
9585
{
9686
// Try to handle the 401 response with the selected scheme
97-
var (handled, recommendedScheme) = await _credentialProvider.HandleUnauthorizedResponseAsync(
87+
var (handled, recommendedScheme) = await credentialProvider.HandleUnauthorizedResponseAsync(
9888
response,
9989
bestSchemeMatch,
10090
cancellationToken).ConfigureAwait(false);
@@ -108,14 +98,8 @@ private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
10898

10999
_currentScheme = recommendedScheme ?? bestSchemeMatch;
110100
}
111-
catch (McpException)
112-
{
113-
// Re-throw McpExceptions as-is to preserve the original error information
114-
throw;
115-
}
116101
catch (Exception ex)
117102
{
118-
// Wrap other exceptions with additional context while preserving the original exception
119103
throw new McpException(
120104
$"Failed to handle unauthorized response with scheme '{bestSchemeMatch}'. " +
121105
"The authentication provider encountered an error while processing the authentication challenge.",
@@ -128,7 +112,6 @@ private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
128112
#if NET
129113
VersionPolicy = originalRequest.VersionPolicy,
130114
#endif
131-
Content = originalRequest.Content
132115
};
133116

134117
// Copy headers except Authorization which we'll set separately
@@ -139,23 +122,10 @@ private async Task<HttpResponseMessage> HandleUnauthorizedResponseAsync(
139122
retryRequest.Headers.TryAddWithoutValidation(header.Key, header.Value);
140123
}
141124
}
142-
#if NET
143-
foreach (var property in originalRequest.Options)
144-
{
145-
retryRequest.Options.Set(new HttpRequestOptionsKey<object?>(property.Key), property.Value);
146-
}
147-
#else
148-
foreach (var property in originalRequest.Properties)
149-
{
150-
retryRequest.Properties.Add(property);
151-
}
152-
#endif
153125

154-
// Add the new authorization header
155126
await AddAuthorizationHeaderAsync(retryRequest, _currentScheme, cancellationToken).ConfigureAwait(false);
156127

157-
// Send the retry request
158-
return await base.SendAsync(retryRequest, cancellationToken).ConfigureAwait(false);
128+
return await base.SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false);
159129
}
160130

161131
return response; // Return the original response if we couldn't handle it
@@ -189,7 +159,7 @@ private async Task AddAuthorizationHeaderAsync(HttpRequestMessage request, strin
189159
{
190160
if (request.RequestUri != null)
191161
{
192-
var token = await _credentialProvider.GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false);
162+
var token = await credentialProvider.GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false);
193163
if (!string.IsNullOrEmpty(token))
194164
{
195165
request.Headers.Authorization = new AuthenticationHeaderValue(scheme, token);

src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ namespace ModelContextProtocol.Client;
1313
internal sealed partial class AutoDetectingClientSessionTransport : ITransport
1414
{
1515
private readonly SseClientTransportOptions _options;
16-
private readonly HttpClient _httpClient;
16+
private readonly McpHttpClient _httpClient;
1717
private readonly ILoggerFactory? _loggerFactory;
1818
private readonly ILogger _logger;
1919
private readonly string _name;
2020
private readonly Channel<JsonRpcMessage> _messageChannel;
2121

22-
public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
22+
public AutoDetectingClientSessionTransport(string endpointName, SseClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory)
2323
{
2424
Throw.IfNull(transportOptions);
2525
Throw.IfNull(httpClient);
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using ModelContextProtocol.Protocol;
2+
using System.Diagnostics;
3+
4+
#if NET
5+
using System.Net.Http.Json;
6+
#else
7+
using System.Text;
8+
using System.Text.Json;
9+
#endif
10+
11+
namespace ModelContextProtocol.Client;
12+
13+
internal class McpHttpClient(HttpClient httpClient)
14+
{
15+
internal virtual async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken)
16+
{
17+
Debug.Assert(request.Content is null, "The request body should only be supplied as a JsonRpcMessage");
18+
Debug.Assert(message is null || request.Method == HttpMethod.Post, "All messages should be sent in POST requests.");
19+
20+
using var content = CreatePostBodyContent(message);
21+
request.Content = content;
22+
return await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
23+
}
24+
25+
private HttpContent? CreatePostBodyContent(JsonRpcMessage? message)
26+
{
27+
if (message is null)
28+
{
29+
return null;
30+
}
31+
32+
#if NET
33+
return JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
34+
#else
35+
return new StringContent(
36+
JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage),
37+
Encoding.UTF8,
38+
"application/json; charset=utf-8"
39+
);
40+
#endif
41+
}
42+
}

src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace ModelContextProtocol.Client;
1515
/// </summary>
1616
internal sealed partial class SseClientSessionTransport : TransportBase
1717
{
18-
private readonly HttpClient _httpClient;
18+
private readonly McpHttpClient _httpClient;
1919
private readonly SseClientTransportOptions _options;
2020
private readonly Uri _sseEndpoint;
2121
private Uri? _messageEndpoint;
@@ -31,7 +31,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase
3131
public SseClientSessionTransport(
3232
string endpointName,
3333
SseClientTransportOptions transportOptions,
34-
HttpClient httpClient,
34+
McpHttpClient httpClient,
3535
Channel<JsonRpcMessage>? messageChannel,
3636
ILoggerFactory? loggerFactory)
3737
: base(endpointName, messageChannel, loggerFactory)
@@ -74,25 +74,16 @@ public override async Task SendMessageAsync(
7474
if (_messageEndpoint == null)
7575
throw new InvalidOperationException("Transport not connected");
7676

77-
using var content = new StringContent(
78-
JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage),
79-
Encoding.UTF8,
80-
"application/json"
81-
);
82-
8377
string messageId = "(no id)";
8478

8579
if (message is JsonRpcMessageWithId messageWithId)
8680
{
8781
messageId = messageWithId.Id.ToString();
8882
}
8983

90-
using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
91-
{
92-
Content = content,
93-
};
84+
using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint);
9485
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null);
95-
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
86+
var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false);
9687

9788
if (!response.IsSuccessStatusCode)
9889
{
@@ -154,11 +145,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
154145
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
155146
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null);
156147

157-
using var response = await _httpClient.SendAsync(
158-
request,
159-
HttpCompletionOption.ResponseHeadersRead,
160-
cancellationToken
161-
).ConfigureAwait(false);
148+
using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false);
162149

163150
response.EnsureSuccessStatusCode();
164151

0 commit comments

Comments
 (0)