Skip to content

Commit c5cefd2

Browse files
authored
Handle registry authentication requests that use the Basic scheme (#217)
1 parent 4eb4f98 commit c5cefd2

File tree

1 file changed

+55
-40
lines changed

1 file changed

+55
-40
lines changed

Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,23 @@ namespace Microsoft.NET.Build.Containers;
1818
/// </summary>
1919
public partial class AuthHandshakeMessageHandler : DelegatingHandler
2020
{
21-
private record AuthInfo(Uri Realm, string Service, string Scope);
21+
private record AuthInfo(Uri Realm, string Service, string? Scope);
2222

2323
/// <summary>
24-
/// Cache of most-recently-recieved token for each server.
24+
/// Cache of most-recently-received token for each server.
2525
/// </summary>
26-
private static ConcurrentDictionary<string, string> TokenCache = new();
26+
private static ConcurrentDictionary<string, AuthenticationHeaderValue> HostAuthenticationCache = new();
2727

2828
/// <summary>
2929
/// the www-authenticate header must have realm, service, and scope information, so this method parses it into that shape if present
3030
/// </summary>
3131
/// <param name="msg"></param>
3232
/// <param name="authInfo"></param>
3333
/// <returns></returns>
34-
private static bool TryParseAuthenticationInfo(HttpResponseMessage msg, [NotNullWhen(true)] out AuthInfo? authInfo)
34+
private static bool TryParseAuthenticationInfo(HttpResponseMessage msg, [NotNullWhen(true)] out string? scheme, [NotNullWhen(true)] out AuthInfo? authInfo)
3535
{
3636
authInfo = null;
37+
scheme = null;
3738

3839
var authenticateHeader = msg.Headers.WwwAuthenticate;
3940
if (!authenticateHeader.Any())
@@ -42,18 +43,19 @@ private static bool TryParseAuthenticationInfo(HttpResponseMessage msg, [NotNull
4243
}
4344

4445
AuthenticationHeaderValue header = authenticateHeader.First();
45-
if (header is { Scheme: "Bearer", Parameter: string args })
46+
if (header is { Scheme: "Bearer" or "Basic", Parameter: string bearerArgs })
4647
{
47-
48+
scheme = header.Scheme;
4849
Dictionary<string, string> keyValues = new();
49-
50-
foreach (Match match in BearerParameterSplitter().Matches(args))
50+
foreach (Match match in BearerParameterSplitter().Matches(bearerArgs))
5151
{
5252
keyValues.Add(match.Groups["key"].Value, match.Groups["value"].Value);
5353
}
5454

55-
if (keyValues.TryGetValue("realm", out string? realm) && keyValues.TryGetValue("service", out string? service) && keyValues.TryGetValue("scope", out string? scope))
55+
if (keyValues.TryGetValue("realm", out string? realm) && keyValues.TryGetValue("service", out string? service))
5656
{
57+
string? scope = null;
58+
keyValues.TryGetValue("scope", out scope);
5759
authInfo = new AuthInfo(new Uri(realm), service, scope);
5860
return true;
5961
}
@@ -85,14 +87,14 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
8587
/// <param name="scope"></param>
8688
/// <param name="cancellationToken"></param>
8789
/// <returns></returns>
88-
private async Task<string> GetTokenAsync(Uri realm, string service, string scope, CancellationToken cancellationToken)
90+
private async Task<AuthenticationHeaderValue?> GetAuthenticationAsync(string scheme, Uri realm, string service, string? scope, CancellationToken cancellationToken)
8991
{
9092
// Allow overrides for auth via environment variables
9193
string? credU = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectUser);
9294
string? credP = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectPass);
9395

9496
// fetch creds for the host
95-
DockerCredentials? privateRepoCreds;
97+
DockerCredentials? privateRepoCreds;
9698

9799
if (!string.IsNullOrEmpty(credU) && !string.IsNullOrEmpty(credP))
98100
{
@@ -109,33 +111,46 @@ private async Task<string> GetTokenAsync(Uri realm, string service, string scope
109111
throw new CredentialRetrievalException(realm.Host, e);
110112
}
111113
}
112-
113-
// use those creds when calling the token provider
114-
var header = privateRepoCreds.Username == "<token>"
115-
? new AuthenticationHeaderValue("Bearer", privateRepoCreds.Password)
116-
: new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
117-
var builder = new UriBuilder(realm);
118-
var queryDict = System.Web.HttpUtility.ParseQueryString("");
119-
queryDict["service"] = service;
120-
queryDict["scope"] = scope;
121-
builder.Query = queryDict.ToString();
122-
var message = new HttpRequestMessage(HttpMethod.Get, builder.ToString());
123-
message.Headers.Authorization = header;
124-
125-
var tokenResponse = await base.SendAsync(message, cancellationToken);
126-
tokenResponse.EnsureSuccessStatusCode();
127-
128-
TokenResponse? token = JsonSerializer.Deserialize<TokenResponse>(tokenResponse.Content.ReadAsStream());
129-
if (token is null)
114+
115+
if (scheme is "Basic")
130116
{
131-
throw new ArgumentException("Could not deserialize token from JSON");
117+
var basicAuth = new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
118+
return HostAuthenticationCache.AddOrUpdate(realm.Host, basicAuth, (previous, current) => current);
132119
}
120+
else if (scheme is "Bearer")
121+
{
122+
// use those creds when calling the token provider
123+
var header = privateRepoCreds.Username == "<token>"
124+
? new AuthenticationHeaderValue("Bearer", privateRepoCreds.Password)
125+
: new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
126+
var builder = new UriBuilder(realm);
127+
var queryDict = System.Web.HttpUtility.ParseQueryString("");
128+
queryDict["service"] = service;
129+
if (scope is string s)
130+
{
131+
queryDict["scope"] = s;
132+
}
133+
builder.Query = queryDict.ToString();
134+
var message = new HttpRequestMessage(HttpMethod.Get, builder.ToString());
135+
message.Headers.Authorization = header;
136+
137+
var tokenResponse = await base.SendAsync(message, cancellationToken);
138+
tokenResponse.EnsureSuccessStatusCode();
133139

134-
// save the retrieved token in the cache.
135-
// if we encounter a previous token (very possible due to concurrent upload)
136-
// use the more recent token.
137-
TokenCache.AddOrUpdate(realm.Host, token.ResolvedToken, (previous, current) => current);
138-
return token.ResolvedToken;
140+
TokenResponse? token = JsonSerializer.Deserialize<TokenResponse>(tokenResponse.Content.ReadAsStream());
141+
if (token is null)
142+
{
143+
throw new ArgumentException("Could not deserialize token from JSON");
144+
}
145+
146+
// save the retrieved token in the cache
147+
var bearerAuth = new AuthenticationHeaderValue("Bearer", token.ResolvedToken);
148+
return HostAuthenticationCache.AddOrUpdate(realm.Host, bearerAuth, (previous, current) => current);
149+
}
150+
else
151+
{
152+
return null;
153+
}
139154
}
140155

141156
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
@@ -146,21 +161,21 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
146161
}
147162

148163
// attempt to use cached token for the request if available
149-
if (TokenCache.TryGetValue(request.RequestUri.Host, out string? cachedToken))
164+
if (HostAuthenticationCache.TryGetValue(request.RequestUri.Host, out AuthenticationHeaderValue? cachedAuthentication))
150165
{
151-
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", cachedToken);
166+
request.Headers.Authorization = cachedAuthentication;
152167
}
153168

154169
var response = await base.SendAsync(request, cancellationToken);
155170
if (response is { StatusCode: HttpStatusCode.OK })
156171
{
157172
return response;
158173
}
159-
else if (response is { StatusCode: HttpStatusCode.Unauthorized } && TryParseAuthenticationInfo(response, out AuthInfo? authInfo))
174+
else if (response is { StatusCode: HttpStatusCode.Unauthorized } && TryParseAuthenticationInfo(response, out string? scheme, out AuthInfo? authInfo))
160175
{
161-
if (await GetTokenAsync(authInfo.Realm, authInfo.Service, authInfo.Scope, cancellationToken) is string fetchedToken)
176+
if (await GetAuthenticationAsync(scheme, authInfo.Realm, authInfo.Service, authInfo.Scope, cancellationToken) is AuthenticationHeaderValue authentication)
162177
{
163-
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", fetchedToken);
178+
request.Headers.Authorization = authentication;
164179
return await base.SendAsync(request, cancellationToken);
165180
}
166181
return response;

0 commit comments

Comments
 (0)