44
55using System ;
66using System . Collections . Concurrent ;
7+ using System . Linq ;
8+ using System . Runtime . Caching ;
79using System . Security ;
10+ using System . Security . Cryptography ;
11+ using System . Text ;
812using System . Threading ;
913using System . Threading . Tasks ;
1014using Azure . Core ;
@@ -24,6 +28,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2428 /// </summary>
2529 private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
2630 = new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
31+ private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
32+ private static readonly int s_accountPwCacheTtlInHours = 2 ;
2733 private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
2834 private static readonly string s_defaultScopeSuffix = "/.default" ;
2935 private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
@@ -171,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
171177 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
172178 }
173179
174- AuthenticationResult result ;
180+ AuthenticationResult result = null ;
175181 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
176182 {
177183 AccessToken accessToken = await new ClientSecretCredential ( audience , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
@@ -207,86 +213,87 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
207213
208214 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
209215 {
210- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
211- {
212- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
213- . WithCorrelationId ( parameters . ConnectionId )
214- . WithUsername ( parameters . UserId )
215- . ExecuteAsync ( cancellationToken : cts . Token )
216- . ConfigureAwait ( false ) ;
217- }
218- else
219- {
220- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
221- . WithCorrelationId ( parameters . ConnectionId )
222- . ExecuteAsync ( cancellationToken : cts . Token )
223- . ConfigureAwait ( false ) ;
224- }
225- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
226- }
227- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
228- {
229- SecureString password = new SecureString ( ) ;
230- foreach ( char c in parameters . Password )
231- password . AppendChar ( c ) ;
232- password . MakeReadOnly ( ) ;
233-
234- result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
235- . WithCorrelationId ( parameters . ConnectionId )
236- . ExecuteAsync ( cancellationToken : cts . Token )
237- . ConfigureAwait ( false ) ;
238- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
239- }
240- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
241- parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
242- {
243- // Fetch available accounts from 'app' instance
244- System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
216+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
245217
246- IAccount account = default ;
247- if ( accounts . MoveNext ( ) )
218+ if ( null == result )
248219 {
249220 if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
250221 {
251- do
252- {
253- IAccount currentVal = accounts . Current ;
254- if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
255- {
256- account = currentVal ;
257- break ;
258- }
259- }
260- while ( accounts . MoveNext ( ) ) ;
222+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
223+ . WithCorrelationId ( parameters . ConnectionId )
224+ . WithUsername ( parameters . UserId )
225+ . ExecuteAsync ( cancellationToken : cts . Token )
226+ . ConfigureAwait ( false ) ;
261227 }
262228 else
263229 {
264- account = accounts . Current ;
230+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
231+ . WithCorrelationId ( parameters . ConnectionId )
232+ . ExecuteAsync ( cancellationToken : cts . Token )
233+ . ConfigureAwait ( false ) ;
265234 }
235+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
236+ }
237+ }
238+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
239+ {
240+ string pwCacheKey = GetAccountPwCacheKey ( parameters ) ;
241+ object previousPw = s_accountPwCache . Get ( pwCacheKey ) ;
242+ byte [ ] currPwHash = GetHash ( parameters . Password ) ;
243+
244+ if ( null != previousPw &&
245+ previousPw is byte [ ] previousPwBytes &&
246+ // Only get the cached token if the current password hash matches the previously used password hash
247+ currPwHash . SequenceEqual ( previousPwBytes ) )
248+ {
249+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
266250 }
267251
268- if ( null != account )
252+ if ( null == result )
269253 {
270- try
254+ SecureString password = new SecureString ( ) ;
255+ foreach ( char c in parameters . Password )
256+ password . AppendChar ( c ) ;
257+ password . MakeReadOnly ( ) ;
258+
259+ result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
260+ . WithCorrelationId ( parameters . ConnectionId )
261+ . ExecuteAsync ( cancellationToken : cts . Token )
262+ . ConfigureAwait ( false ) ;
263+
264+ // We cache the password hash to ensure future connection requests include a validated password
265+ // when we check for a cached MSAL account. Otherwise, a connection request with the same username
266+ // against the same tenant could succeed with an invalid password when we re-use the cached token.
267+ if ( ! s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) )
271268 {
272- // If 'account' is available in 'app', we use the same to acquire token silently.
273- // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
274- result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
275- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
276- }
277- catch ( MsalUiRequiredException )
278- {
279- // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
280- // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
281- // or the user needs to perform two factor authentication.
282- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
283- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
269+ s_accountPwCache . Remove ( pwCacheKey ) ;
270+ s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) ;
284271 }
272+
273+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
285274 }
286- else
275+ }
276+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
277+ parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
278+ {
279+ try
280+ {
281+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
282+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
283+ }
284+ catch ( MsalUiRequiredException )
285+ {
286+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
287+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
288+ // or the user needs to perform two factor authentication.
289+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
290+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
291+ }
292+
293+ if ( null == result )
287294 {
288295 // If no existing 'account' is found, we request user to sign in interactively.
289- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
296+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
290297 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
291298 }
292299 }
@@ -299,8 +306,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
299306 return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
300307 }
301308
302- private async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
303- SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts )
309+ private static async Task < AuthenticationResult > TryAcquireTokenSilent ( IPublicClientApplication app , SqlAuthenticationParameters parameters ,
310+ string [ ] scopes , CancellationTokenSource cts )
311+ {
312+ AuthenticationResult result = null ;
313+
314+ // Fetch available accounts from 'app' instance
315+ System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
316+
317+ IAccount account = default ;
318+ if ( accounts . MoveNext ( ) )
319+ {
320+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
321+ {
322+ do
323+ {
324+ IAccount currentVal = accounts . Current ;
325+ if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
326+ {
327+ account = currentVal ;
328+ break ;
329+ }
330+ }
331+ while ( accounts . MoveNext ( ) ) ;
332+ }
333+ else
334+ {
335+ account = accounts . Current ;
336+ }
337+ }
338+
339+ if ( null != account )
340+ {
341+ // If 'account' is available in 'app', we use the same to acquire token silently.
342+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
343+ result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
344+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
345+ }
346+
347+ return result ;
348+ }
349+
350+ private static async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
351+ SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts , ICustomWebUi customWebUI , Func < DeviceCodeResult , Task > deviceCodeFlowCallback )
304352 {
305353 try
306354 {
@@ -319,11 +367,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
319367 */
320368 ctsInteractive . CancelAfter ( 180000 ) ;
321369#endif
322- if ( _customWebUI != null )
370+ if ( customWebUI != null )
323371 {
324372 return await app . AcquireTokenInteractive ( scopes )
325373 . WithCorrelationId ( connectionId )
326- . WithCustomWebUi ( _customWebUI )
374+ . WithCustomWebUi ( customWebUI )
327375 . WithLoginHint ( userId )
328376 . ExecuteAsync ( ctsInteractive . Token )
329377 . ConfigureAwait ( false ) ;
@@ -357,7 +405,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
357405 else
358406 {
359407 AuthenticationResult result = await app . AcquireTokenWithDeviceCode ( scopes ,
360- deviceCodeResult => _deviceCodeFlowCallback ( deviceCodeResult ) )
408+ deviceCodeResult => deviceCodeFlowCallback ( deviceCodeResult ) )
361409 . WithCorrelationId ( connectionId )
362410 . ExecuteAsync ( cancellationToken : cts . Token )
363411 . ConfigureAwait ( false ) ;
@@ -410,6 +458,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
410458 return clientApplicationInstance ;
411459 }
412460
461+ private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
462+ {
463+ return parameters . Authority + "+" + parameters . UserId ;
464+ }
465+
466+ private static byte [ ] GetHash ( string input )
467+ {
468+ byte [ ] unhashedBytes = Encoding . Unicode . GetBytes ( input ) ;
469+ SHA256 sha256 = SHA256 . Create ( ) ;
470+ byte [ ] hashedBytes = sha256 . ComputeHash ( unhashedBytes ) ;
471+ return hashedBytes ;
472+ }
473+
413474 private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
414475 {
415476 IPublicClientApplication publicClientApplication ;
0 commit comments