@@ -18,22 +18,23 @@ namespace Microsoft.NET.Build.Containers;
18
18
/// </summary>
19
19
public partial class AuthHandshakeMessageHandler : DelegatingHandler
20
20
{
21
- private record AuthInfo ( Uri Realm , string Service , string Scope ) ;
21
+ private record AuthInfo ( Uri Realm , string Service , string ? Scope ) ;
22
22
23
23
/// <summary>
24
- /// Cache of most-recently-recieved token for each server.
24
+ /// Cache of most-recently-received token for each server.
25
25
/// </summary>
26
- private static ConcurrentDictionary < string , string > TokenCache = new ( ) ;
26
+ private static ConcurrentDictionary < string , AuthenticationHeaderValue > HostAuthenticationCache = new ( ) ;
27
27
28
28
/// <summary>
29
29
/// the www-authenticate header must have realm, service, and scope information, so this method parses it into that shape if present
30
30
/// </summary>
31
31
/// <param name="msg"></param>
32
32
/// <param name="authInfo"></param>
33
33
/// <returns></returns>
34
- private static bool TryParseAuthenticationInfo ( HttpResponseMessage msg , [ NotNullWhen ( true ) ] out AuthInfo ? authInfo )
34
+ private static bool TryParseAuthenticationInfo ( HttpResponseMessage msg , [ NotNullWhen ( true ) ] out string ? scheme , [ NotNullWhen ( true ) ] out AuthInfo ? authInfo )
35
35
{
36
36
authInfo = null ;
37
+ scheme = null ;
37
38
38
39
var authenticateHeader = msg . Headers . WwwAuthenticate ;
39
40
if ( ! authenticateHeader . Any ( ) )
@@ -42,18 +43,19 @@ private static bool TryParseAuthenticationInfo(HttpResponseMessage msg, [NotNull
42
43
}
43
44
44
45
AuthenticationHeaderValue header = authenticateHeader . First ( ) ;
45
- if ( header is { Scheme : "Bearer" , Parameter : string args } )
46
+ if ( header is { Scheme : "Bearer" or "Basic" , Parameter : string bearerArgs } )
46
47
{
47
-
48
+ scheme = header . Scheme ;
48
49
Dictionary < string , string > keyValues = new ( ) ;
49
-
50
- foreach ( Match match in BearerParameterSplitter ( ) . Matches ( args ) )
50
+ foreach ( Match match in BearerParameterSplitter ( ) . Matches ( bearerArgs ) )
51
51
{
52
52
keyValues . Add ( match . Groups [ "key" ] . Value , match . Groups [ "value" ] . Value ) ;
53
53
}
54
54
55
- if ( keyValues . TryGetValue ( "realm" , out string ? realm ) && keyValues . TryGetValue ( "service" , out string ? service ) && keyValues . TryGetValue ( "scope" , out string ? scope ) )
55
+ if ( keyValues . TryGetValue ( "realm" , out string ? realm ) && keyValues . TryGetValue ( "service" , out string ? service ) )
56
56
{
57
+ string ? scope = null ;
58
+ keyValues . TryGetValue ( "scope" , out scope ) ;
57
59
authInfo = new AuthInfo ( new Uri ( realm ) , service , scope ) ;
58
60
return true ;
59
61
}
@@ -85,14 +87,14 @@ private record TokenResponse(string? token, string? access_token, int? expires_i
85
87
/// <param name="scope"></param>
86
88
/// <param name="cancellationToken"></param>
87
89
/// <returns></returns>
88
- private async Task < string > GetTokenAsync ( Uri realm , string service , string scope , CancellationToken cancellationToken )
90
+ private async Task < AuthenticationHeaderValue ? > GetAuthenticationAsync ( string scheme , Uri realm , string service , string ? scope , CancellationToken cancellationToken )
89
91
{
90
92
// Allow overrides for auth via environment variables
91
93
string ? credU = Environment . GetEnvironmentVariable ( ContainerHelpers . HostObjectUser ) ;
92
94
string ? credP = Environment . GetEnvironmentVariable ( ContainerHelpers . HostObjectPass ) ;
93
95
94
96
// fetch creds for the host
95
- DockerCredentials ? privateRepoCreds ;
97
+ DockerCredentials ? privateRepoCreds ;
96
98
97
99
if ( ! string . IsNullOrEmpty ( credU ) && ! string . IsNullOrEmpty ( credP ) )
98
100
{
@@ -109,33 +111,46 @@ private async Task<string> GetTokenAsync(Uri realm, string service, string scope
109
111
throw new CredentialRetrievalException ( realm . Host , e ) ;
110
112
}
111
113
}
112
-
113
- // use those creds when calling the token provider
114
- var header = privateRepoCreds . Username == "<token>"
115
- ? new AuthenticationHeaderValue ( "Bearer" , privateRepoCreds . Password )
116
- : new AuthenticationHeaderValue ( "Basic" , Convert . ToBase64String ( Encoding . ASCII . GetBytes ( $ "{ privateRepoCreds . Username } :{ privateRepoCreds . Password } ") ) ) ;
117
- var builder = new UriBuilder ( realm ) ;
118
- var queryDict = System . Web . HttpUtility . ParseQueryString ( "" ) ;
119
- queryDict [ "service" ] = service ;
120
- queryDict [ "scope" ] = scope ;
121
- builder . Query = queryDict . ToString ( ) ;
122
- var message = new HttpRequestMessage ( HttpMethod . Get , builder . ToString ( ) ) ;
123
- message . Headers . Authorization = header ;
124
-
125
- var tokenResponse = await base . SendAsync ( message , cancellationToken ) ;
126
- tokenResponse . EnsureSuccessStatusCode ( ) ;
127
-
128
- TokenResponse ? token = JsonSerializer . Deserialize < TokenResponse > ( tokenResponse . Content . ReadAsStream ( ) ) ;
129
- if ( token is null )
114
+
115
+ if ( scheme is "Basic" )
130
116
{
131
- throw new ArgumentException ( "Could not deserialize token from JSON" ) ;
117
+ var basicAuth = new AuthenticationHeaderValue ( "Basic" , Convert . ToBase64String ( Encoding . ASCII . GetBytes ( $ "{ privateRepoCreds . Username } :{ privateRepoCreds . Password } ") ) ) ;
118
+ return HostAuthenticationCache . AddOrUpdate ( realm . Host , basicAuth , ( previous , current ) => current ) ;
132
119
}
120
+ else if ( scheme is "Bearer" )
121
+ {
122
+ // use those creds when calling the token provider
123
+ var header = privateRepoCreds . Username == "<token>"
124
+ ? new AuthenticationHeaderValue ( "Bearer" , privateRepoCreds . Password )
125
+ : new AuthenticationHeaderValue ( "Basic" , Convert . ToBase64String ( Encoding . ASCII . GetBytes ( $ "{ privateRepoCreds . Username } :{ privateRepoCreds . Password } ") ) ) ;
126
+ var builder = new UriBuilder ( realm ) ;
127
+ var queryDict = System . Web . HttpUtility . ParseQueryString ( "" ) ;
128
+ queryDict [ "service" ] = service ;
129
+ if ( scope is string s )
130
+ {
131
+ queryDict [ "scope" ] = s ;
132
+ }
133
+ builder . Query = queryDict . ToString ( ) ;
134
+ var message = new HttpRequestMessage ( HttpMethod . Get , builder . ToString ( ) ) ;
135
+ message . Headers . Authorization = header ;
136
+
137
+ var tokenResponse = await base . SendAsync ( message , cancellationToken ) ;
138
+ tokenResponse . EnsureSuccessStatusCode ( ) ;
133
139
134
- // save the retrieved token in the cache.
135
- // if we encounter a previous token (very possible due to concurrent upload)
136
- // use the more recent token.
137
- TokenCache . AddOrUpdate ( realm . Host , token . ResolvedToken , ( previous , current ) => current ) ;
138
- return token . ResolvedToken ;
140
+ TokenResponse ? token = JsonSerializer . Deserialize < TokenResponse > ( tokenResponse . Content . ReadAsStream ( ) ) ;
141
+ if ( token is null )
142
+ {
143
+ throw new ArgumentException ( "Could not deserialize token from JSON" ) ;
144
+ }
145
+
146
+ // save the retrieved token in the cache
147
+ var bearerAuth = new AuthenticationHeaderValue ( "Bearer" , token . ResolvedToken ) ;
148
+ return HostAuthenticationCache . AddOrUpdate ( realm . Host , bearerAuth , ( previous , current ) => current ) ;
149
+ }
150
+ else
151
+ {
152
+ return null ;
153
+ }
139
154
}
140
155
141
156
protected override async Task < HttpResponseMessage > SendAsync ( HttpRequestMessage request , CancellationToken cancellationToken )
@@ -146,21 +161,21 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
146
161
}
147
162
148
163
// attempt to use cached token for the request if available
149
- if ( TokenCache . TryGetValue ( request . RequestUri . Host , out string ? cachedToken ) )
164
+ if ( HostAuthenticationCache . TryGetValue ( request . RequestUri . Host , out AuthenticationHeaderValue ? cachedAuthentication ) )
150
165
{
151
- request . Headers . Authorization = new AuthenticationHeaderValue ( "Bearer" , cachedToken ) ;
166
+ request . Headers . Authorization = cachedAuthentication ;
152
167
}
153
168
154
169
var response = await base . SendAsync ( request , cancellationToken ) ;
155
170
if ( response is { StatusCode : HttpStatusCode . OK } )
156
171
{
157
172
return response ;
158
173
}
159
- else if ( response is { StatusCode : HttpStatusCode . Unauthorized } && TryParseAuthenticationInfo ( response , out AuthInfo ? authInfo ) )
174
+ else if ( response is { StatusCode : HttpStatusCode . Unauthorized } && TryParseAuthenticationInfo ( response , out string ? scheme , out AuthInfo ? authInfo ) )
160
175
{
161
- if ( await GetTokenAsync ( authInfo . Realm , authInfo . Service , authInfo . Scope , cancellationToken ) is string fetchedToken )
176
+ if ( await GetAuthenticationAsync ( scheme , authInfo . Realm , authInfo . Service , authInfo . Scope , cancellationToken ) is AuthenticationHeaderValue authentication )
162
177
{
163
- request . Headers . Authorization = new AuthenticationHeaderValue ( "Bearer" , fetchedToken ) ;
178
+ request . Headers . Authorization = authentication ;
164
179
return await base . SendAsync ( request , cancellationToken ) ;
165
180
}
166
181
return response ;
0 commit comments