Skip to content

Commit 78eda60

Browse files
committed
Enable auth caching by URL
Bring back the cache but instead of by host, use "url except query string" as the key, since that is likely to match a scope reasonably closely. Implemented as a standalone class to make the interface at calling sites easier to understand than the direct ConcurrentDictionary access, plus allow centralized policy decisions.
1 parent 8c81b4c commit 78eda60

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,19 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
7777
/// Credentials for the request are retrieved from the credential provider, then used to acquire a token.
7878
/// That token is cached for some duration on a per-host basis.
7979
/// </summary>
80-
/// <param name="realm"></param>
80+
/// <param name="uri"></param>
8181
/// <param name="service"></param>
8282
/// <param name="scope"></param>
8383
/// <param name="cancellationToken"></param>
8484
/// <returns></returns>
85-
private async Task<AuthenticationHeaderValue?> GetAuthenticationAsync(string scheme, Uri realm, string service, string? scope, CancellationToken cancellationToken)
85+
private async Task<AuthenticationHeaderValue?> GetAuthenticationAsync(string scheme, Uri uri, string service, string? scope, CancellationToken cancellationToken)
8686
{
8787
// Allow overrides for auth via environment variables
8888
string? credU = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectUser);
8989
string? credP = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectPass);
9090

9191
// fetch creds for the host
92-
DockerCredentials? privateRepoCreds;
92+
DockerCredentials? privateRepoCreds;
9393

9494
if (!string.IsNullOrEmpty(credU) && !string.IsNullOrEmpty(credP))
9595
{
@@ -99,26 +99,26 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
9999
{
100100
try
101101
{
102-
privateRepoCreds = await CredsProvider.GetCredentialsAsync(realm.Host);
102+
privateRepoCreds = await CredsProvider.GetCredentialsAsync(uri.Host);
103103
}
104104
catch (Exception e)
105105
{
106-
throw new CredentialRetrievalException(realm.Host, e);
106+
throw new CredentialRetrievalException(uri.Host, e);
107107
}
108108
}
109109

110110
if (scheme is "Basic")
111111
{
112112
var basicAuth = new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
113-
return basicAuth;
113+
return AuthHeaderCache.AddOrUpdate(uri, basicAuth);
114114
}
115115
else if (scheme is "Bearer")
116116
{
117117
// use those creds when calling the token provider
118118
var header = privateRepoCreds.Username == "<token>"
119119
? new AuthenticationHeaderValue("Bearer", privateRepoCreds.Password)
120120
: new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}")));
121-
var builder = new UriBuilder(realm);
121+
var builder = new UriBuilder(uri);
122122
var queryDict = System.Web.HttpUtility.ParseQueryString("");
123123
queryDict["service"] = service;
124124
if (scope is string s)
@@ -140,7 +140,7 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
140140

141141
// save the retrieved token in the cache
142142
var bearerAuth = new AuthenticationHeaderValue("Bearer", token.ResolvedToken);
143-
return bearerAuth;
143+
return AuthHeaderCache.AddOrUpdate(uri, bearerAuth);
144144
}
145145
else
146146
{
@@ -155,7 +155,11 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
155155
throw new ArgumentException("No RequestUri specified", nameof(request));
156156
}
157157

158-
// TODO: attempt to use cached token for the request if available
158+
// attempt to use cached token for the request if available
159+
if (AuthHeaderCache.TryGet(request.RequestUri, out AuthenticationHeaderValue? cachedAuthentication))
160+
{
161+
request.Headers.Authorization = cachedAuthentication;
162+
}
159163

160164
var response = await base.SendAsync(request, cancellationToken);
161165
if (response is { StatusCode: HttpStatusCode.OK })
@@ -166,7 +170,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
166170
{
167171
if (await GetAuthenticationAsync(scheme, authInfo.Realm, authInfo.Service, authInfo.Scope, cancellationToken) is AuthenticationHeaderValue authentication)
168172
{
169-
request.Headers.Authorization = authentication;
173+
request.Headers.Authorization = AuthHeaderCache.AddOrUpdate(request.RequestUri, authentication);
170174
return await base.SendAsync(request, cancellationToken);
171175
}
172176
return response;
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using System.Diagnostics.CodeAnalysis;
5+
using System.Linq;
6+
using System.Net.Http.Headers;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
10+
namespace Microsoft.NET.Build.Containers;
11+
12+
internal static class AuthHeaderCache
13+
{
14+
15+
private static ConcurrentDictionary<string, AuthenticationHeaderValue> HostAuthenticationCache = new();
16+
17+
public static bool TryGet(Uri uri, [NotNullWhen(true)] out AuthenticationHeaderValue? header)
18+
{
19+
return HostAuthenticationCache.TryGetValue(GetCacheKey(uri), out header);
20+
}
21+
22+
public static AuthenticationHeaderValue AddOrUpdate(Uri uri, AuthenticationHeaderValue header)
23+
{
24+
return HostAuthenticationCache.AddOrUpdate(GetCacheKey(uri), header, (_, _) => header);
25+
}
26+
27+
private static string GetCacheKey(Uri uri)
28+
{
29+
return uri.Host + uri.AbsolutePath;
30+
}
31+
}

0 commit comments

Comments
 (0)