diff --git a/src/Accounts/Accounts/ChangeLog.md b/src/Accounts/Accounts/ChangeLog.md index 42b17f79e7ec..adfcafac3a43 100644 --- a/src/Accounts/Accounts/ChangeLog.md +++ b/src/Accounts/Accounts/ChangeLog.md @@ -18,6 +18,7 @@ - Additional information about change #1 --> ## Upcoming Release +* Fixed an issue causing `Connect-AzAccount -KeyVaultAccessToken` not working [#13127] * Fixed null reference and method case insensitive in `Invoke-AzRestMethod` ## Version 2.1.0 diff --git a/src/Accounts/Authentication.Test/AuthenticationFactoryTests.cs b/src/Accounts/Authentication.Test/AuthenticationFactoryTests.cs index f535148516a0..440580fbc974 100644 --- a/src/Accounts/Authentication.Test/AuthenticationFactoryTests.cs +++ b/src/Accounts/Authentication.Test/AuthenticationFactoryTests.cs @@ -28,6 +28,8 @@ using Xunit; using Xunit.Abstractions; using System.Text.RegularExpressions; +using System.Net.Http; +using System.Threading; namespace Common.Authentication.Test { @@ -561,5 +563,51 @@ private string GetFunctionsResourceId(string resourceIdOrEndpointName, IAzureEnv return resourceId; } + + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void CanGetServiceClientCredentialsWithAccessToken() + { + AzureSessionInitializer.InitializeAzureSession(); + IAuthenticatorBuilder authenticatorBuilder = new DefaultAuthenticatorBuilder(); + AzureSession.Instance.RegisterComponent(AuthenticatorBuilder.AuthenticatorBuilderKey, () => authenticatorBuilder); + PowerShellTokenCacheProvider factory = new InMemoryTokenCacheProvider(); + AzureSession.Instance.RegisterComponent(PowerShellTokenCacheProvider.PowerShellTokenCacheProviderKey, () => factory); + string tenant = Guid.NewGuid().ToString(); + string userId = "user1@contoso.org"; + var armToken = Guid.NewGuid().ToString(); + var graphToken = Guid.NewGuid().ToString(); + var kvToken = Guid.NewGuid().ToString(); + var account = new AzureAccount + { + Id = userId, + Type = AzureAccount.AccountType.AccessToken + }; + account.SetTenants(tenant); + account.SetAccessToken(armToken); + account.SetProperty(AzureAccount.Property.GraphAccessToken, graphToken); + account.SetProperty(AzureAccount.Property.KeyVaultAccessToken, kvToken); + var authFactory = new AuthenticationFactory(); + var environment = AzureEnvironment.PublicEnvironments.Values.First(); + var mockContext = new AzureContext() + { + Account = account + }; + var credentials = authFactory.GetServiceClientCredentials(mockContext); + VerifyAccessTokenInServiceClientCredentials(credentials, armToken); + credentials = authFactory.GetServiceClientCredentials(mockContext, AzureEnvironment.Endpoint.Graph); + VerifyAccessTokenInServiceClientCredentials(credentials, graphToken); + credentials = authFactory.GetServiceClientCredentials(mockContext, AzureEnvironment.Endpoint.AzureKeyVaultServiceEndpointResourceId); + VerifyAccessTokenInServiceClientCredentials(credentials, kvToken); + } + + private void VerifyAccessTokenInServiceClientCredentials(Microsoft.Rest.ServiceClientCredentials cred, string expected) + { + using (var request = new HttpRequestMessage()) + { + cred.ProcessHttpRequestAsync(request, new CancellationToken()).ConfigureAwait(false).GetAwaiter().GetResult(); + Assert.Equal(expected, request.Headers.Authorization.Parameter); + } + } } } diff --git a/src/Accounts/Authentication/Factories/AuthenticationFactory.cs b/src/Accounts/Authentication/Factories/AuthenticationFactory.cs index 374a46eec289..e839af671d95 100644 --- a/src/Accounts/Authentication/Factories/AuthenticationFactory.cs +++ b/src/Accounts/Authentication/Factories/AuthenticationFactory.cs @@ -438,11 +438,14 @@ private string GetFunctionsResourceId(string resourceIdOrEndpointName, IAzureEnv private string GetEndpointToken(IAzureAccount account, string targetEndpoint) { string tokenKey = AzureAccount.Property.AccessToken; - if (targetEndpoint == AzureEnvironment.Endpoint.Graph) + if (string.Equals(targetEndpoint, AzureEnvironment.Endpoint.Graph, StringComparison.OrdinalIgnoreCase)) { tokenKey = AzureAccount.Property.GraphAccessToken; } - + if (string.Equals(targetEndpoint, AzureEnvironment.Endpoint.AzureKeyVaultServiceEndpointResourceId, StringComparison.OrdinalIgnoreCase)) + { + tokenKey = AzureAccount.Property.KeyVaultAccessToken; + } return account.GetProperty(tokenKey); }