Skip to content

Commit 3c8ef33

Browse files
Active Directory Default authentication improvements (#1360)
* Active Directory Default improvements * Disable test on Linux * Feedback applied * Update src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs Co-authored-by: Javad <[email protected]> * Add configureawaits Co-authored-by: Javad <[email protected]>
1 parent aa95223 commit 3c8ef33

File tree

2 files changed

+57
-23
lines changed

2 files changed

+57
-23
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -118,28 +118,46 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
118118

119119
string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
120120
string[] scopes = new string[] { scope };
121+
TokenRequestContext tokenRequestContext = new(scopes);
122+
123+
/* We split audience from Authority URL here. Audience can be one of the following:
124+
* The Azure AD authority audience enumeration
125+
* The tenant ID, which can be:
126+
* - A GUID (the ID of your Azure AD instance), for single-tenant applications
127+
* - A domain name associated with your Azure AD instance (also for single-tenant applications)
128+
* One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
129+
* - `organizations` for a multitenant application
130+
* - `consumers` to sign in users only with their personal accounts
131+
* - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
132+
*
133+
* MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
134+
* If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
135+
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
136+
**/
121137

122138
int seperatorIndex = parameters.Authority.LastIndexOf('/');
123-
string tenantId = parameters.Authority.Substring(seperatorIndex + 1);
124139
string authority = parameters.Authority.Remove(seperatorIndex + 1);
125-
126-
TokenRequestContext tokenRequestContext = new TokenRequestContext(scopes);
140+
string audience = parameters.Authority.Substring(seperatorIndex + 1);
127141
string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;
128142

129143
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault)
130144
{
131-
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new DefaultAzureCredentialOptions()
145+
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
132146
{
133147
AuthorityHost = new Uri(authority),
134-
ManagedIdentityClientId = clientId,
135-
InteractiveBrowserTenantId = tenantId,
136-
SharedTokenCacheTenantId = tenantId,
137-
SharedTokenCacheUsername = clientId,
138-
VisualStudioCodeTenantId = tenantId,
139-
VisualStudioTenantId = tenantId,
148+
SharedTokenCacheTenantId = audience,
149+
VisualStudioCodeTenantId = audience,
150+
VisualStudioTenantId = audience,
140151
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
141152
};
142-
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
153+
154+
// Optionally set clientId when available
155+
if (clientId is not null)
156+
{
157+
defaultAzureCredentialOptions.ManagedIdentityClientId = clientId;
158+
defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId;
159+
}
160+
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
143161
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
144162
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
145163
}
@@ -148,15 +166,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
148166

149167
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
150168
{
151-
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
169+
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
152170
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
153171
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
154172
}
155173

156174
AuthenticationResult result;
157175
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
158176
{
159-
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
177+
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
160178
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
161179
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
162180
}
@@ -194,13 +212,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
194212
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
195213
.WithCorrelationId(parameters.ConnectionId)
196214
.WithUsername(parameters.UserId)
197-
.ExecuteAsync(cancellationToken: cts.Token);
215+
.ExecuteAsync(cancellationToken: cts.Token)
216+
.ConfigureAwait(false);
198217
}
199218
else
200219
{
201220
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
202221
.WithCorrelationId(parameters.ConnectionId)
203-
.ExecuteAsync(cancellationToken: cts.Token);
222+
.ExecuteAsync(cancellationToken: cts.Token)
223+
.ConfigureAwait(false);
204224
}
205225
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
206226
}
@@ -213,14 +233,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
213233

214234
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
215235
.WithCorrelationId(parameters.ConnectionId)
216-
.ExecuteAsync(cancellationToken: cts.Token);
236+
.ExecuteAsync(cancellationToken: cts.Token)
237+
.ConfigureAwait(false);
217238
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
218239
}
219240
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
220241
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
221242
{
222243
// Fetch available accounts from 'app' instance
223-
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync()).GetEnumerator();
244+
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
224245

225246
IAccount account = default;
226247
if (accounts.MoveNext())
@@ -250,22 +271,22 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
250271
{
251272
// If 'account' is available in 'app', we use the same to acquire token silently.
252273
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
253-
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token);
274+
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
254275
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
255276
}
256277
catch (MsalUiRequiredException)
257278
{
258279
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
259280
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
260281
// or the user needs to perform two factor authentication.
261-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
282+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
262283
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
263284
}
264285
}
265286
else
266287
{
267288
// If no existing 'account' is found, we request user to sign in interactively.
268-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
289+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
269290
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
270291
}
271292
}
@@ -304,7 +325,8 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
304325
.WithCorrelationId(connectionId)
305326
.WithCustomWebUi(_customWebUI)
306327
.WithLoginHint(userId)
307-
.ExecuteAsync(ctsInteractive.Token);
328+
.ExecuteAsync(ctsInteractive.Token)
329+
.ConfigureAwait(false);
308330
}
309331
else
310332
{
@@ -328,15 +350,17 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
328350
return await app.AcquireTokenInteractive(scopes)
329351
.WithCorrelationId(connectionId)
330352
.WithLoginHint(userId)
331-
.ExecuteAsync(ctsInteractive.Token);
353+
.ExecuteAsync(ctsInteractive.Token)
354+
.ConfigureAwait(false);
332355
}
333356
}
334357
else
335358
{
336359
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
337360
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
338361
.WithCorrelationId(connectionId)
339-
.ExecuteAsync(cancellationToken: cts.Token);
362+
.ExecuteAsync(cancellationToken: cts.Token)
363+
.ConfigureAwait(false);
340364
return result;
341365
}
342366
}

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ private static void ConnectAndDisconnect(string connectionString, SqlCredential
7676
private static bool IsAADConnStringsSetup() => DataTestUtility.IsAADPasswordConnStrSetup();
7777
private static bool IsManagedIdentitySetup() => DataTestUtility.ManagedIdentitySupported;
7878

79+
[PlatformSpecific(TestPlatforms.Windows)]
80+
[ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))]
81+
public static void KustoDatabaseTest()
82+
{
83+
// This is a sample Kusto database that can be connected by any AD account.
84+
using SqlConnection connection = new SqlConnection("Data Source=help.kusto.windows.net; Authentication=Active Directory Default;Trust Server Certificate=True;");
85+
connection.Open();
86+
Assert.True(connection.State == System.Data.ConnectionState.Open);
87+
}
88+
7989
[ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))]
8090
public static void AccessTokenTest()
8191
{

0 commit comments

Comments
 (0)