Skip to content

Commit e5befdc

Browse files
authored
Backport 3.1.x Fix | Throttling of token requests by calling AcquireTokenSilent (#1926)
1 parent 947b7fc commit e5befdc

File tree

1 file changed

+140
-74
lines changed

1 file changed

+140
-74
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 140 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
using System;
66
using System.Collections.Concurrent;
7+
using System.Linq;
8+
using System.Runtime.Caching;
79
using System.Security;
10+
using System.Security.Cryptography;
11+
using System.Text;
812
using System.Threading;
913
using System.Threading.Tasks;
1014
using Azure.Core;
@@ -24,6 +28,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2428
/// </summary>
2529
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
2630
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
31+
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
32+
private static readonly int s_accountPwCacheTtlInHours = 2;
2733
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
2834
private static readonly string s_defaultScopeSuffix = "/.default";
2935
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
@@ -133,7 +139,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
133139
VisualStudioTenantId = tenantId,
134140
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
135141
};
136-
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
142+
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
137143
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
138144
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
139145
}
@@ -142,15 +148,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
142148

143149
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
144150
{
145-
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
151+
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
146152
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
147153
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
148154
}
149155

150-
AuthenticationResult result;
156+
AuthenticationResult result = null;
151157
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
152158
{
153-
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
159+
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
154160
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
155161
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
156162
}
@@ -183,83 +189,86 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
183189

184190
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
185191
{
186-
if (!string.IsNullOrEmpty(parameters.UserId))
187-
{
188-
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
189-
.WithCorrelationId(parameters.ConnectionId)
190-
.WithUsername(parameters.UserId)
191-
.ExecuteAsync(cancellationToken: cts.Token);
192-
}
193-
else
194-
{
195-
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
196-
.WithCorrelationId(parameters.ConnectionId)
197-
.ExecuteAsync(cancellationToken: cts.Token);
198-
}
199-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
200-
}
201-
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
202-
{
203-
SecureString password = new SecureString();
204-
foreach (char c in parameters.Password)
205-
password.AppendChar(c);
206-
password.MakeReadOnly();
207-
208-
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
209-
.WithCorrelationId(parameters.ConnectionId)
210-
.ExecuteAsync(cancellationToken: cts.Token);
211-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
212-
}
213-
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
214-
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
215-
{
216-
// Fetch available accounts from 'app' instance
217-
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync()).GetEnumerator();
218-
219-
IAccount account = default;
220-
if (accounts.MoveNext())
192+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
193+
194+
if (null == result)
221195
{
222196
if (!string.IsNullOrEmpty(parameters.UserId))
223197
{
224-
do
225-
{
226-
IAccount currentVal = accounts.Current;
227-
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
228-
{
229-
account = currentVal;
230-
break;
231-
}
232-
}
233-
while (accounts.MoveNext());
198+
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
199+
.WithCorrelationId(parameters.ConnectionId)
200+
.WithUsername(parameters.UserId)
201+
.ExecuteAsync(cancellationToken: cts.Token)
202+
.ConfigureAwait(false);
234203
}
235204
else
236205
{
237-
account = accounts.Current;
206+
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
207+
.WithCorrelationId(parameters.ConnectionId)
208+
.ExecuteAsync(cancellationToken: cts.Token)
209+
.ConfigureAwait(false);
238210
}
211+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
212+
}
213+
}
214+
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
215+
{
216+
string pwCacheKey = GetAccountPwCacheKey(parameters);
217+
object previousPw = s_accountPwCache.Get(pwCacheKey);
218+
byte[] currPwHash = GetHash(parameters.Password);
219+
220+
if (null != previousPw &&
221+
previousPw is byte[] previousPwBytes &&
222+
// Only get the cached token if the current password hash matches the previously used password hash
223+
currPwHash.SequenceEqual(previousPwBytes))
224+
{
225+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
239226
}
240227

241-
if (null != account)
228+
if (null == result)
242229
{
243-
try
244-
{
245-
// If 'account' is available in 'app', we use the same to acquire token silently.
246-
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
247-
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token);
248-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
249-
}
250-
catch (MsalUiRequiredException)
230+
SecureString password = new SecureString();
231+
foreach (char c in parameters.Password)
232+
password.AppendChar(c);
233+
password.MakeReadOnly();
234+
235+
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
236+
.WithCorrelationId(parameters.ConnectionId)
237+
.ExecuteAsync(cancellationToken: cts.Token)
238+
.ConfigureAwait(false);
239+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
240+
241+
// We cache the password hash to ensure future connection requests include a validated password
242+
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
243+
// against the same tenant could succeed with an invalid password when we re-use the cached token.
244+
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
251245
{
252-
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
253-
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
254-
// or the user needs to perform two factor authentication.
255-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
256-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
246+
s_accountPwCache.Remove(pwCacheKey);
247+
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
257248
}
258249
}
259-
else
250+
}
251+
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
252+
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
253+
{
254+
try
255+
{
256+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
257+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
258+
}
259+
catch (MsalUiRequiredException)
260+
{
261+
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
262+
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
263+
// or the user needs to perform two factor authentication.
264+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
265+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
266+
}
267+
268+
if (null == result)
260269
{
261270
// If no existing 'account' is found, we request user to sign in interactively.
262-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
271+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
263272
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
264273
}
265274
}
@@ -272,8 +281,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
272281
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
273282
}
274283

275-
private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
276-
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts)
284+
private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters,
285+
string[] scopes, CancellationTokenSource cts)
286+
{
287+
AuthenticationResult result = null;
288+
289+
// Fetch available accounts from 'app' instance
290+
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
291+
292+
IAccount account = default;
293+
if (accounts.MoveNext())
294+
{
295+
if (!string.IsNullOrEmpty(parameters.UserId))
296+
{
297+
do
298+
{
299+
IAccount currentVal = accounts.Current;
300+
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
301+
{
302+
account = currentVal;
303+
break;
304+
}
305+
}
306+
while (accounts.MoveNext());
307+
}
308+
else
309+
{
310+
account = accounts.Current;
311+
}
312+
}
313+
314+
if (null != account)
315+
{
316+
// If 'account' is available in 'app', we use the same to acquire token silently.
317+
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
318+
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
319+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
320+
}
321+
322+
return result;
323+
}
324+
325+
private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
326+
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
277327
{
278328
try
279329
{
@@ -292,13 +342,14 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
292342
*/
293343
ctsInteractive.CancelAfter(180000);
294344
#endif
295-
if (_customWebUI != null)
345+
if (customWebUI != null)
296346
{
297347
return await app.AcquireTokenInteractive(scopes)
298348
.WithCorrelationId(connectionId)
299-
.WithCustomWebUi(_customWebUI)
349+
.WithCustomWebUi(customWebUI)
300350
.WithLoginHint(userId)
301-
.ExecuteAsync(ctsInteractive.Token);
351+
.ExecuteAsync(ctsInteractive.Token)
352+
.ConfigureAwait(false);
302353
}
303354
else
304355
{
@@ -322,15 +373,17 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
322373
return await app.AcquireTokenInteractive(scopes)
323374
.WithCorrelationId(connectionId)
324375
.WithLoginHint(userId)
325-
.ExecuteAsync(ctsInteractive.Token);
376+
.ExecuteAsync(ctsInteractive.Token)
377+
.ConfigureAwait(false);
326378
}
327379
}
328380
else
329381
{
330382
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
331-
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
383+
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult))
332384
.WithCorrelationId(connectionId)
333-
.ExecuteAsync(cancellationToken: cts.Token);
385+
.ExecuteAsync(cancellationToken: cts.Token)
386+
.ConfigureAwait(false);
334387
return result;
335388
}
336389
}
@@ -380,6 +433,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
380433
return clientApplicationInstance;
381434
}
382435

436+
private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
437+
{
438+
return parameters.Authority + "+" + parameters.UserId;
439+
}
440+
441+
private static byte[] GetHash(string input)
442+
{
443+
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
444+
SHA256 sha256 = SHA256.Create();
445+
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
446+
return hashedBytes;
447+
}
448+
383449
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
384450
{
385451
IPublicClientApplication publicClientApplication;

0 commit comments

Comments
 (0)