44
55using System ;
66using System . Collections . Concurrent ;
7+ using System . Linq ;
8+ using System . Runtime . Caching ;
79using System . Security ;
10+ using System . Security . Cryptography ;
11+ using System . Text ;
812using System . Threading ;
913using System . Threading . Tasks ;
1014using Azure . Core ;
@@ -24,6 +28,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2428 /// </summary>
2529 private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
2630 = new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
31+ private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
32+ private static readonly int s_accountPwCacheTtlInHours = 2 ;
2733 private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
2834 private static readonly string s_defaultScopeSuffix = "/.default" ;
2935 private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
@@ -133,7 +139,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
133139 VisualStudioTenantId = tenantId ,
134140 ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
135141 } ;
136- AccessToken accessToken = await new DefaultAzureCredential ( defaultAzureCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) ;
142+ AccessToken accessToken = await new DefaultAzureCredential ( defaultAzureCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
137143 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
138144 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
139145 }
@@ -142,15 +148,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
142148
143149 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryManagedIdentity || parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryMSI )
144150 {
145- AccessToken accessToken = await new ManagedIdentityCredential ( clientId , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) ;
151+ AccessToken accessToken = await new ManagedIdentityCredential ( clientId , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
146152 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
147153 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
148154 }
149155
150- AuthenticationResult result ;
156+ AuthenticationResult result = null ;
151157 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
152158 {
153- AccessToken accessToken = await new ClientSecretCredential ( tenantId , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) ;
159+ AccessToken accessToken = await new ClientSecretCredential ( tenantId , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
154160 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
155161 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
156162 }
@@ -183,83 +189,86 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
183189
184190 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
185191 {
186- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
187- {
188- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
189- . WithCorrelationId ( parameters . ConnectionId )
190- . WithUsername ( parameters . UserId )
191- . ExecuteAsync ( cancellationToken : cts . Token ) ;
192- }
193- else
194- {
195- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
196- . WithCorrelationId ( parameters . ConnectionId )
197- . ExecuteAsync ( cancellationToken : cts . Token ) ;
198- }
199- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
200- }
201- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
202- {
203- SecureString password = new SecureString ( ) ;
204- foreach ( char c in parameters . Password )
205- password . AppendChar ( c ) ;
206- password . MakeReadOnly ( ) ;
207-
208- result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
209- . WithCorrelationId ( parameters . ConnectionId )
210- . ExecuteAsync ( cancellationToken : cts . Token ) ;
211- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
212- }
213- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
214- parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
215- {
216- // Fetch available accounts from 'app' instance
217- System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) ) . GetEnumerator ( ) ;
218-
219- IAccount account = default ;
220- if ( accounts . MoveNext ( ) )
192+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
193+
194+ if ( null == result )
221195 {
222196 if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
223197 {
224- do
225- {
226- IAccount currentVal = accounts . Current ;
227- if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
228- {
229- account = currentVal ;
230- break ;
231- }
232- }
233- while ( accounts . MoveNext ( ) ) ;
198+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
199+ . WithCorrelationId ( parameters . ConnectionId )
200+ . WithUsername ( parameters . UserId )
201+ . ExecuteAsync ( cancellationToken : cts . Token )
202+ . ConfigureAwait ( false ) ;
234203 }
235204 else
236205 {
237- account = accounts . Current ;
206+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
207+ . WithCorrelationId ( parameters . ConnectionId )
208+ . ExecuteAsync ( cancellationToken : cts . Token )
209+ . ConfigureAwait ( false ) ;
238210 }
211+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
212+ }
213+ }
214+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
215+ {
216+ string pwCacheKey = GetAccountPwCacheKey ( parameters ) ;
217+ object previousPw = s_accountPwCache . Get ( pwCacheKey ) ;
218+ byte [ ] currPwHash = GetHash ( parameters . Password ) ;
219+
220+ if ( null != previousPw &&
221+ previousPw is byte [ ] previousPwBytes &&
222+ // Only get the cached token if the current password hash matches the previously used password hash
223+ currPwHash . SequenceEqual ( previousPwBytes ) )
224+ {
225+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
239226 }
240227
241- if ( null != account )
228+ if ( null == result )
242229 {
243- try
244- {
245- // If 'account' is available in 'app', we use the same to acquire token silently.
246- // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
247- result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) ;
248- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
249- }
250- catch ( MsalUiRequiredException )
230+ SecureString password = new SecureString ( ) ;
231+ foreach ( char c in parameters . Password )
232+ password . AppendChar ( c ) ;
233+ password . MakeReadOnly ( ) ;
234+
235+ result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
236+ . WithCorrelationId ( parameters . ConnectionId )
237+ . ExecuteAsync ( cancellationToken : cts . Token )
238+ . ConfigureAwait ( false ) ;
239+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
240+
241+ // We cache the password hash to ensure future connection requests include a validated password
242+ // when we check for a cached MSAL account. Otherwise, a connection request with the same username
243+ // against the same tenant could succeed with an invalid password when we re-use the cached token.
244+ if ( ! s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) )
251245 {
252- // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
253- // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
254- // or the user needs to perform two factor authentication.
255- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) ;
256- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
246+ s_accountPwCache . Remove ( pwCacheKey ) ;
247+ s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) ;
257248 }
258249 }
259- else
250+ }
251+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
252+ parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
253+ {
254+ try
255+ {
256+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
257+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
258+ }
259+ catch ( MsalUiRequiredException )
260+ {
261+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
262+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
263+ // or the user needs to perform two factor authentication.
264+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
265+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
266+ }
267+
268+ if ( null == result )
260269 {
261270 // If no existing 'account' is found, we request user to sign in interactively.
262- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) ;
271+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
263272 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
264273 }
265274 }
@@ -272,8 +281,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
272281 return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
273282 }
274283
275- private async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
276- SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts )
284+ private static async Task < AuthenticationResult > TryAcquireTokenSilent ( IPublicClientApplication app , SqlAuthenticationParameters parameters ,
285+ string [ ] scopes , CancellationTokenSource cts )
286+ {
287+ AuthenticationResult result = null ;
288+
289+ // Fetch available accounts from 'app' instance
290+ System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
291+
292+ IAccount account = default ;
293+ if ( accounts . MoveNext ( ) )
294+ {
295+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
296+ {
297+ do
298+ {
299+ IAccount currentVal = accounts . Current ;
300+ if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
301+ {
302+ account = currentVal ;
303+ break ;
304+ }
305+ }
306+ while ( accounts . MoveNext ( ) ) ;
307+ }
308+ else
309+ {
310+ account = accounts . Current ;
311+ }
312+ }
313+
314+ if ( null != account )
315+ {
316+ // If 'account' is available in 'app', we use the same to acquire token silently.
317+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
318+ result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
319+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
320+ }
321+
322+ return result ;
323+ }
324+
325+ private static async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
326+ SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts , ICustomWebUi customWebUI , Func < DeviceCodeResult , Task > deviceCodeFlowCallback )
277327 {
278328 try
279329 {
@@ -292,13 +342,14 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
292342 */
293343 ctsInteractive . CancelAfter ( 180000 ) ;
294344#endif
295- if ( _customWebUI != null )
345+ if ( customWebUI != null )
296346 {
297347 return await app . AcquireTokenInteractive ( scopes )
298348 . WithCorrelationId ( connectionId )
299- . WithCustomWebUi ( _customWebUI )
349+ . WithCustomWebUi ( customWebUI )
300350 . WithLoginHint ( userId )
301- . ExecuteAsync ( ctsInteractive . Token ) ;
351+ . ExecuteAsync ( ctsInteractive . Token )
352+ . ConfigureAwait ( false ) ;
302353 }
303354 else
304355 {
@@ -322,15 +373,17 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
322373 return await app . AcquireTokenInteractive ( scopes )
323374 . WithCorrelationId ( connectionId )
324375 . WithLoginHint ( userId )
325- . ExecuteAsync ( ctsInteractive . Token ) ;
376+ . ExecuteAsync ( ctsInteractive . Token )
377+ . ConfigureAwait ( false ) ;
326378 }
327379 }
328380 else
329381 {
330382 AuthenticationResult result = await app . AcquireTokenWithDeviceCode ( scopes ,
331- deviceCodeResult => _deviceCodeFlowCallback ( deviceCodeResult ) )
383+ deviceCodeResult => deviceCodeFlowCallback ( deviceCodeResult ) )
332384 . WithCorrelationId ( connectionId )
333- . ExecuteAsync ( cancellationToken : cts . Token ) ;
385+ . ExecuteAsync ( cancellationToken : cts . Token )
386+ . ConfigureAwait ( false ) ;
334387 return result ;
335388 }
336389 }
@@ -380,6 +433,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
380433 return clientApplicationInstance ;
381434 }
382435
436+ private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
437+ {
438+ return parameters . Authority + "+" + parameters . UserId ;
439+ }
440+
441+ private static byte [ ] GetHash ( string input )
442+ {
443+ byte [ ] unhashedBytes = Encoding . Unicode . GetBytes ( input ) ;
444+ SHA256 sha256 = SHA256 . Create ( ) ;
445+ byte [ ] hashedBytes = sha256 . ComputeHash ( unhashedBytes ) ;
446+ return hashedBytes ;
447+ }
448+
383449 private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
384450 {
385451 IPublicClientApplication publicClientApplication ;
0 commit comments