44
55using System ;
66using System . Collections . Concurrent ;
7- using System . Security ;
7+ using System . Linq ;
8+ using System . Runtime . Caching ;
9+ using System . Security . Cryptography ;
10+ using System . Text ;
811using System . Threading ;
912using System . Threading . Tasks ;
1013using Azure . Core ;
@@ -24,6 +27,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2427 /// </summary>
2528 private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
2629 = new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
30+ private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
31+ private static readonly int s_accountPwCacheTtlInHours = 2 ;
2732 private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
2833 private static readonly string s_defaultScopeSuffix = "/.default" ;
2934 private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
@@ -172,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
172177 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
173178 }
174179
175- AuthenticationResult result ;
180+ AuthenticationResult result = null ;
176181 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
177182 {
178183 AccessToken accessToken = await new ClientSecretCredential ( audience , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
@@ -208,82 +213,82 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
208213
209214 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
210215 {
211- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
212- {
213- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
214- . WithCorrelationId ( parameters . ConnectionId )
215- . WithUsername ( parameters . UserId )
216- . ExecuteAsync ( cancellationToken : cts . Token )
217- . ConfigureAwait ( false ) ;
218- }
219- else
220- {
221- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
222- . WithCorrelationId ( parameters . ConnectionId )
223- . ExecuteAsync ( cancellationToken : cts . Token )
224- . ConfigureAwait ( false ) ;
225- }
226- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
227- }
228- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
229- {
230- result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , parameters . Password )
231- . WithCorrelationId ( parameters . ConnectionId )
232- . ExecuteAsync ( cancellationToken : cts . Token )
233- . ConfigureAwait ( false ) ;
234-
235- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
236- }
237- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
238- parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
239- {
240- // Fetch available accounts from 'app' instance
241- System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
216+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
242217
243- IAccount account = default ;
244- if ( accounts . MoveNext ( ) )
218+ if ( null == result )
245219 {
246220 if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
247221 {
248- do
249- {
250- IAccount currentVal = accounts . Current ;
251- if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
252- {
253- account = currentVal ;
254- break ;
255- }
256- }
257- while ( accounts . MoveNext ( ) ) ;
222+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
223+ . WithCorrelationId ( parameters . ConnectionId )
224+ . WithUsername ( parameters . UserId )
225+ . ExecuteAsync ( cancellationToken : cts . Token )
226+ . ConfigureAwait ( false ) ;
258227 }
259228 else
260229 {
261- account = accounts . Current ;
230+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
231+ . WithCorrelationId ( parameters . ConnectionId )
232+ . ExecuteAsync ( cancellationToken : cts . Token )
233+ . ConfigureAwait ( false ) ;
262234 }
235+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
236+ }
237+ }
238+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
239+ {
240+ string pwCacheKey = GetAccountPwCacheKey ( parameters ) ;
241+ object previousPw = s_accountPwCache . Get ( pwCacheKey ) ;
242+ byte [ ] currPwHash = GetHash ( parameters . Password ) ;
243+
244+ if ( null != previousPw &&
245+ previousPw is byte [ ] previousPwBytes &&
246+ // Only get the cached token if the current password hash matches the previously used password hash
247+ currPwHash . SequenceEqual ( previousPwBytes ) )
248+ {
249+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
263250 }
264251
265- if ( null != account )
252+ if ( null == result )
266253 {
267- try
268- {
269- // If 'account' is available in 'app', we use the same to acquire token silently.
270- // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
271- result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
272- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
273- }
274- catch ( MsalUiRequiredException )
254+ result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , parameters . Password )
255+ . WithCorrelationId ( parameters . ConnectionId )
256+ . ExecuteAsync ( cancellationToken : cts . Token )
257+ . ConfigureAwait ( false ) ;
258+
259+ // We cache the password hash to ensure future connection requests include a validated password
260+ // when we check for a cached MSAL account. Otherwise, a connection request with the same username
261+ // against the same tenant could succeed with an invalid password when we re-use the cached token.
262+ if ( ! s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) )
275263 {
276- // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
277- // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
278- // or the user needs to perform two factor authentication.
279- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
280- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
264+ s_accountPwCache . Remove ( pwCacheKey ) ;
265+ s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) ;
281266 }
267+
268+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
282269 }
283- else
270+ }
271+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
272+ parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
273+ {
274+ try
275+ {
276+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
277+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
278+ }
279+ catch ( MsalUiRequiredException )
280+ {
281+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
282+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
283+ // or the user needs to perform two factor authentication.
284+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
285+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
286+ }
287+
288+ if ( null == result )
284289 {
285290 // If no existing 'account' is found, we request user to sign in interactively.
286- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
291+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
287292 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
288293 }
289294 }
@@ -296,8 +301,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
296301 return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
297302 }
298303
299- private async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
300- SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts )
304+ private static async Task < AuthenticationResult > TryAcquireTokenSilent ( IPublicClientApplication app , SqlAuthenticationParameters parameters ,
305+ string [ ] scopes , CancellationTokenSource cts )
306+ {
307+ AuthenticationResult result = null ;
308+
309+ // Fetch available accounts from 'app' instance
310+ System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
311+
312+ IAccount account = default ;
313+ if ( accounts . MoveNext ( ) )
314+ {
315+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
316+ {
317+ do
318+ {
319+ IAccount currentVal = accounts . Current ;
320+ if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
321+ {
322+ account = currentVal ;
323+ break ;
324+ }
325+ }
326+ while ( accounts . MoveNext ( ) ) ;
327+ }
328+ else
329+ {
330+ account = accounts . Current ;
331+ }
332+ }
333+
334+ if ( null != account )
335+ {
336+ // If 'account' is available in 'app', we use the same to acquire token silently.
337+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
338+ result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
339+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
340+ }
341+
342+ return result ;
343+ }
344+
345+ private static async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
346+ SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts , ICustomWebUi customWebUI , Func < DeviceCodeResult , Task > deviceCodeFlowCallback )
301347 {
302348 try
303349 {
@@ -316,11 +362,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
316362 */
317363 ctsInteractive . CancelAfter ( 180000 ) ;
318364#endif
319- if ( _customWebUI != null )
365+ if ( customWebUI != null )
320366 {
321367 return await app . AcquireTokenInteractive ( scopes )
322368 . WithCorrelationId ( connectionId )
323- . WithCustomWebUi ( _customWebUI )
369+ . WithCustomWebUi ( customWebUI )
324370 . WithLoginHint ( userId )
325371 . ExecuteAsync ( ctsInteractive . Token )
326372 . ConfigureAwait ( false ) ;
@@ -354,7 +400,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
354400 else
355401 {
356402 AuthenticationResult result = await app . AcquireTokenWithDeviceCode ( scopes ,
357- deviceCodeResult => _deviceCodeFlowCallback ( deviceCodeResult ) )
403+ deviceCodeResult => deviceCodeFlowCallback ( deviceCodeResult ) )
358404 . WithCorrelationId ( connectionId )
359405 . ExecuteAsync ( cancellationToken : cts . Token )
360406 . ConfigureAwait ( false ) ;
@@ -407,6 +453,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
407453 return clientApplicationInstance ;
408454 }
409455
456+ private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
457+ {
458+ return parameters . Authority + "+" + parameters . UserId ;
459+ }
460+
461+ private static byte [ ] GetHash ( string input )
462+ {
463+ byte [ ] unhashedBytes = Encoding . Unicode . GetBytes ( input ) ;
464+ SHA256 sha256 = SHA256 . Create ( ) ;
465+ byte [ ] hashedBytes = sha256 . ComputeHash ( unhashedBytes ) ;
466+ return hashedBytes ;
467+ }
468+
410469 private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
411470 {
412471 IPublicClientApplication publicClientApplication ;
0 commit comments