diff --git a/apps/managedidentity/example_test.go b/apps/managedidentity/example_test.go new file mode 100644 index 00000000..46ee6256 --- /dev/null +++ b/apps/managedidentity/example_test.go @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package managedidentity_test + +import ( + "fmt" + + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" +) + +// This example demonstrates how to create the managed identity client with system assigned +// +// A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function). +// Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted. +// Use this when your app only needs one identity and doesn’t need to share it across services. +// Learn more: +// https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity +func ExampleNew_systemAssigned() { + systemAssignedClient, err := mi.New(mi.SystemAssigned()) + if err != nil { + fmt.Printf("failed to create client with system-assigned identity: %v", err) + } + _ = systemAssignedClient // Use this client to authenticate to Azure services (e.g., Key Vault, Storage, etc.) + +} + +// This example demonstrates how to create the managed identity client with user assigned +// +// A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. +// User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity +func ExampleNew_userAssigned() { + clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID + + userAssignedClient, err := mi.New( + mi.UserAssignedClientID(clientID), + ) + if err != nil { + fmt.Printf("failed to create client with user-assigned identity: %v", err) + } + _ = userAssignedClient // Use this client for authentication when stable or shared identity is required +} + +// Client Capabilities ("cp1", etc.) +// 'cp1' is a capability that enables specific client behaviors — for example, +// supporting Conditional Access policies that require additional client capabilities. +// This is mostly relevant in scenarios where the identity is used to access resources +// protected by policies like MFA or device compliance. +// In most cases, you won't need to set this unless required by your Azure AD configuration. +// +// Learn more: +// Client capabilities: https://learn.microsoft.com/entra/msal/python/advanced/client-capabilities +func ExampleWithClientCapabilities() { + clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID + + userAssignedClient, err := mi.New( + mi.UserAssignedClientID(clientID), + + mi.WithClientCapabilities([]string{"cp1"}), + ) + if err != nil { + fmt.Printf("failed to create client with user-assigned identity: %v", err) + } + _ = userAssignedClient // Use this client for authentication when stable or shared identity is required +} diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index ca3de432..49f6fddd 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -11,6 +11,8 @@ package managedidentity import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -82,7 +84,7 @@ const ( tokenName = "Tokens" // App Service - appServiceAPIVersion = "2019-08-01" + appServiceAPIVersion = "2025-03-30" // AzureML azureMLAPIVersion = "2017-09-01" @@ -178,6 +180,7 @@ type Client struct { authParams authority.AuthParams retryPolicyEnabled bool canRefresh *atomic.Value + clientCapabilities string } type AcquireTokenOptions struct { @@ -192,14 +195,35 @@ type AcquireTokenOption func(o *AcquireTokenOptions) // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. func WithClaims(claims string) AcquireTokenOption { return func(o *AcquireTokenOptions) { - o.claims = claims + if claims != "" { + o.claims = claims + } + } +} + +// WithClientCapabilities sets the client capabilities to be used in the request. +// For details see https://learn.microsoft.com/entra/identity/conditional-access/concept-continuous-access-evaluation +// The capabilities are passed as a slice of strings, and empty strings are filtered out. +func WithClientCapabilities(capabilities []string) ClientOption { + return func(o *Client) { + var filteredCapabilities []string + for _, cap := range capabilities { + trimmedCap := strings.TrimSpace(cap) + if trimmedCap != "" { + filteredCapabilities = append(filteredCapabilities, trimmedCap) + } + } + o.clientCapabilities = strings.Join(filteredCapabilities, ",") } } // WithHTTPClient allows for a custom HTTP client to be set. +// if nil, the default HTTP client will be used. func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { return func(c *Client) { - c.httpClient = httpClient + if httpClient != nil { + c.httpClient = httpClient + } } } @@ -323,28 +347,29 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac } c.authParams.Scopes = []string{resource} - // ignore cached access tokens when given claims - if o.claims == "" { - stResp, err := cacheManager.Read(ctx, c.authParams) - if err != nil { - return AuthResult{}, err + stResp, err := cacheManager.Read(ctx, c.authParams) + if err != nil { + return AuthResult{}, err + } + ar, err := base.AuthResultFromStorage(stResp) + if err == nil { + if o.claims != "" { + // When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric) + return c.getToken(ctx, resource, ar.AccessToken) } - ar, err := base.AuthResultFromStorage(stResp) - if err == nil { - if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { - defer c.canRefresh.Store(false) - if tr, er := c.getToken(ctx, resource); er == nil { - return tr, nil - } + if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { + defer c.canRefresh.Store(false) + if tr, er := c.getToken(ctx, resource, ""); er == nil { + return tr, nil } - ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - return ar, err } + ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err } - return c.getToken(ctx, resource) + return c.getToken(ctx, resource, "") } -func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) { +func (c Client) getToken(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { switch c.source { case AzureArc: return c.acquireTokenForAzureArc(ctx, resource) @@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro case DefaultToIMDS: return c.acquireTokenForIMDS(ctx, resource) case AppService: - return c.acquireTokenForAppService(ctx, resource) + return c.acquireTokenForAppService(ctx, resource, revokedToken) case ServiceFabric: - return c.acquireTokenForServiceFabric(ctx, resource) + return c.acquireTokenForServiceFabric(ctx, resource, revokedToken) default: return AuthResult{}, fmt.Errorf("unsupported source %q", c.source) } } -func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) { - req, err := createAppServiceAuthRequest(ctx, c.miType, resource) +func (c Client) acquireTokenForAppService(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { + req, err := createAppServiceAuthRequest(ctx, c.miType, resource, revokedToken, c.clientCapabilities) if err != nil { return AuthResult{}, err } @@ -411,8 +436,8 @@ func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (Au return authResultFromToken(c.authParams, tokenResponse) } -func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (AuthResult, error) { - req, err := createServiceFabricAuthRequest(ctx, resource) +func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { + req, err := createServiceFabricAuthRequest(ctx, resource, revokedToken, c.clientCapabilities) if err != nil { return AuthResult{}, err } @@ -569,16 +594,26 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto return r, err } -func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { +func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { return nil, err } req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar)) + q := req.URL.Query() q.Set("api-version", appServiceAPIVersion) q.Set("resource", resource) + + if revokedToken != "" { + q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken)) + } + + if len(cc) > 0 { + q.Set("xms_cc", cc) + } + switch t := id.(type) { case UserAssignedClientID: q.Set(miQueryParameterClientId, string(t)) @@ -594,6 +629,13 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (* return req, nil } +func convertTokenToSHA256HashString(revokedToken string) string { + hash := sha256.New() + hash.Write([]byte(revokedToken)) + hashBytes := hash.Sum(nil) + return hex.EncodeToString(hashBytes) +} + func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { msiEndpoint, err := url.Parse(imdsDefaultEndpoint) if err != nil { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 32b9c4f5..c46e6244 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -5,6 +5,8 @@ package managedidentity import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -964,7 +966,7 @@ func TestAzureArcErrors(t *testing.T) { }, { name: "Invalid file path", - headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey), + headerValue: basicRealm + filepath.Join("path", "to", secretKey), expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"), }, { @@ -1209,3 +1211,124 @@ func TestRefreshInMultipleRequests(t *testing.T) { } close(ch) } + +func TestWithClientCapabilities_TrimsSpaces(t *testing.T) { + setEnvVars(t, AppService) + mockClient := mock.NewClient() + // Reset cache for clean test + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + capabilitiesWithSpaces := []string{" cp1", " cp2 ", " cp3 "} + expectedCapabilities := "cp1,cp2,cp3" + + client, err := New(SystemAssigned(), + WithHTTPClient(mockClient), + WithClientCapabilities(capabilitiesWithSpaces)) + if err != nil { + t.Fatal(err) + } + + if client.clientCapabilities != expectedCapabilities { + t.Errorf("WithClientCapabilities() did not trim spaces correctly, got: %s, want: %s", client.clientCapabilities, expectedCapabilities) + } +} + +// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed +// and a bad access token is retrieved from the cache +func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { + setEnvVars(t, AppService) + localUrl := &url.URL{} + mockClient := mock.NewClient() + // Second response is a successful token response after retrying with claims + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + ) + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // Reset cache for clean test + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(SystemAssigned(), + WithHTTPClient(mockClient), + WithClientCapabilities([]string{"cp1", "cp2"})) + if err != nil { + t.Fatal(err) + } + + // Call AcquireToken which should trigger token revocation flow + result, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } + + // Call AcquireToken which should trigger token revocation flow + result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims")) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + localUrlQuery := localUrl.Query() + + if localUrlQuery.Get(apiVersionQueryParameterName) != appServiceAPIVersion { + t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuery.Get(apiVersionQueryParameterName)) + } + if r := localUrlQuery.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if localUrlQuery.Get("xms_cc") != "cp1,cp2" { + t.Fatalf("Expected client capabilities %q, got %q", "cp1,cp2", localUrlQuery.Get("xms_cc")) + } + hash := sha256.Sum256([]byte(token)) + if localUrlQuery.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) { + t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuery.Get("token_sha256_to_refresh")) + } + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } +} + +func TestConvertTokenToSHA256HashString(t *testing.T) { + tests := []struct { + token string + expectedHash string + }{ + { + token: "test_token", + expectedHash: "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656", + }, + { + token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", + expectedHash: "01588d5a948b6c4facd47866877491b42866b5c10a4d342cf168e994101d352a", + }, + { + token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", + expectedHash: "29c538690068a8ad1797a391bfe23e7fb817b601fc7b78288cb499ab8fd37947", + }, + } + + for _, test := range tests { + hash := convertTokenToSHA256HashString(test.token) + if hash != test.expectedHash { + t.Fatalf("for token %q, expected %q, got %q", test.token, test.expectedHash, hash) + } + } +} diff --git a/apps/managedidentity/servicefabric.go b/apps/managedidentity/servicefabric.go index 535065e9..2b0f7bb2 100644 --- a/apps/managedidentity/servicefabric.go +++ b/apps/managedidentity/servicefabric.go @@ -9,7 +9,7 @@ import ( "os" ) -func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http.Request, error) { +func createServiceFabricAuthRequest(ctx context.Context, resource string, revokedToken string, cc string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { @@ -20,6 +20,14 @@ func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http q := req.URL.Query() q.Set("api-version", serviceFabricAPIVersion) q.Set("resource", resource) + if revokedToken != "" { + q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken)) + } + + if len(cc) > 0 { + q.Set("xms_cc", cc) + } + req.URL.RawQuery = q.Encode() return req, nil } diff --git a/apps/managedidentity/servicefabric_test.go b/apps/managedidentity/servicefabric_test.go index 37c24b0e..f8063655 100644 --- a/apps/managedidentity/servicefabric_test.go +++ b/apps/managedidentity/servicefabric_test.go @@ -5,6 +5,8 @@ package managedidentity import ( "context" + "crypto/sha256" + "encoding/hex" "net/http" "net/url" "strings" @@ -96,6 +98,78 @@ func TestServiceFabricAcquireTokenReturnsTokenSuccess(t *testing.T) { }) } } + +// TestServiceFabricWithClaimsAndBadAccessToken tests the scenario where claims are passed +// and a bad access token is retrieved from the cache +func TestServiceFabricWithClaimsAndBadAccessToken(t *testing.T) { + setEnvVars(t, ServiceFabric) + localUrl := &url.URL{} + mockClient := mock.NewClient() + // Second response is a successful token response after retrying with claims + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + ) + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // Reset cache for clean test + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(SystemAssigned(), + WithHTTPClient(mockClient), + WithClientCapabilities([]string{"c1", "c2"})) + if err != nil { + t.Fatal(err) + } + + // Call AcquireToken which should trigger token revocation flow + result, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } + + // Call AcquireToken which should trigger token revocation flow + result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims")) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + localUrlQuerry := localUrl.Query() + + if localUrlQuerry.Get(apiVersionQueryParameterName) != serviceFabricAPIVersion { + t.Fatalf("api-version not on %s got %s", serviceFabricAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName)) + } + if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if localUrlQuerry.Get("xms_cc") != "c1,c2" { + t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc")) + } + hash := sha256.Sum256([]byte(token)) + if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) { + t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh")) + } + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } +} + func TestServiceFabricErrors(t *testing.T) { setEnvVars(t, ServiceFabric) mockClient := mock.NewClient()