From c7b8d0a099eb35b2563e10c81897d899f9fc8bbb Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 3 Apr 2025 13:19:23 +0100 Subject: [PATCH 01/14] Added support for Token revocation support --- apps/managedidentity/managedidentity.go | 72 +++++++--- apps/managedidentity/managedidentity_test.go | 138 +++++++++++++++++++ 2 files changed, 192 insertions(+), 18 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index ca3de432..947dabf0 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -11,6 +11,8 @@ package managedidentity import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -82,7 +84,7 @@ const ( tokenName = "Tokens" // App Service - appServiceAPIVersion = "2019-08-01" + appServiceAPIVersion = "2025-03-30" // AzureML azureMLAPIVersion = "2017-09-01" @@ -178,6 +180,7 @@ type Client struct { authParams authority.AuthParams retryPolicyEnabled bool canRefresh *atomic.Value + clientCapabilities []string } type AcquireTokenOptions struct { @@ -192,14 +195,34 @@ type AcquireTokenOption func(o *AcquireTokenOptions) // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. func WithClaims(claims string) AcquireTokenOption { return func(o *AcquireTokenOptions) { - o.claims = claims + if claims != "" { + o.claims = claims + } + } +} + +// WithClientCapabilities sets the client capabilities to be used in the request. +// This is used to enable specific features or behaviors in the token request. +// The capabilities are passed as a slice of strings, and empty strings are filtered out. +func WithClientCapabilities(capabilities []string) ClientOption { + return func(o *Client) { + var filteredCapabilities []string + for _, cap := range capabilities { + if cap != "" { + filteredCapabilities = append(filteredCapabilities, cap) + } + } + o.clientCapabilities = filteredCapabilities } } // WithHTTPClient allows for a custom HTTP client to be set. +// if nil, the default HTTP client will be used. func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { return func(c *Client) { - c.httpClient = httpClient + if httpClient != nil { + c.httpClient = httpClient + } } } @@ -323,17 +346,19 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac } c.authParams.Scopes = []string{resource} - // ignore cached access tokens when given claims - if o.claims == "" { - stResp, err := cacheManager.Read(ctx, c.authParams) - if err != nil { - return AuthResult{}, err - } - ar, err := base.AuthResultFromStorage(stResp) - if err == nil { + stResp, err := cacheManager.Read(ctx, c.authParams) + if err != nil { + return AuthResult{}, err + } + ar, err := base.AuthResultFromStorage(stResp) + if err == nil { + if o.claims != "" { + // When the claims are set, we need to passon bad/old token + return c.getToken(ctx, resource, ar.AccessToken) + } else { if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { defer c.canRefresh.Store(false) - if tr, er := c.getToken(ctx, resource); er == nil { + if tr, er := c.getToken(ctx, resource, o.claims); er == nil { return tr, nil } } @@ -341,10 +366,10 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac return ar, err } } - return c.getToken(ctx, resource) + return c.getToken(ctx, resource, "") } -func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) { +func (c Client) getToken(ctx context.Context, resource string, badToken string) (AuthResult, error) { switch c.source { case AzureArc: return c.acquireTokenForAzureArc(ctx, resource) @@ -355,7 +380,7 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro case DefaultToIMDS: return c.acquireTokenForIMDS(ctx, resource) case AppService: - return c.acquireTokenForAppService(ctx, resource) + return c.acquireTokenForAppService(ctx, resource, badToken) case ServiceFabric: return c.acquireTokenForServiceFabric(ctx, resource) default: @@ -363,8 +388,8 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro } } -func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) { - req, err := createAppServiceAuthRequest(ctx, c.miType, resource) +func (c Client) acquireTokenForAppService(ctx context.Context, resource string, badToken string) (AuthResult, error) { + req, err := createAppServiceAuthRequest(ctx, c.miType, resource, badToken, c.clientCapabilities) if err != nil { return AuthResult{}, err } @@ -569,16 +594,27 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto return r, err } -func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { +func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, badToken string, cc []string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { return nil, err } req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar)) + q := req.URL.Query() q.Set("api-version", appServiceAPIVersion) q.Set("resource", resource) + + if badToken != "" { + hash := sha256.Sum256([]byte(badToken)) + q.Set("token_sha256_to_refresh", hex.EncodeToString(hash[:])) + } + + if len(cc) > 0 { + q.Set("xms_cc", strings.Join(cc, ",")) + } + switch t := id.(type) { case UserAssignedClientID: q.Set(miQueryParameterClientId, string(t)) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 32b9c4f5..e553241a 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -5,6 +5,8 @@ package managedidentity import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -696,6 +698,71 @@ func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) { } } +// TestAppServiceWithClientCapabilities tests the scenario when App Service includes client capabilities +func TestAppServiceWithClientCapabilities(t *testing.T) { + setEnvVars(t, AppService) + + testCases := []struct { + name string + expectError bool + expectedStatusCode int + expectedToken string + expectedCapabilities string + capabilities []string + }{ + { + name: "Token Request with Client Capabilities", + expectError: false, + expectedStatusCode: http.StatusOK, + expectedToken: token, + expectedCapabilities: "c1,c2", + capabilities: []string{"c1", "c2"}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mockClient := mock.NewClient() + localUrl := &url.URL{} + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithClientCapabilities(testCase.capabilities)) + if err != nil { + t.Fatal(err) + } + + result, err := client.AcquireToken(context.Background(), resource) + if testCase.expectError { + if err == nil { + t.Fatal("Expected error but got nil") + } + var callErr errors.CallErr + if !errors.As(err, &callErr) { + t.Fatalf("Expected error of type CallErr, got %T", err) + } + } else { + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + if result.AccessToken != testCase.expectedToken { + t.Fatalf("Expected access token %q, got %q", testCase.expectedToken, result.AccessToken) + } + localUrlQuerry := localUrl.Query() + if localUrlQuerry.Get("xms_cc") != testCase.expectedCapabilities { + t.Fatalf("Expected client capabilities %q, got %q", testCase.expectedCapabilities, localUrlQuerry.Get("xms_cc")) + } + } + }) + } +} + func TestAzureMLAcquireTokenReturnsTokenSuccess(t *testing.T) { defaultClientID := "A" t.Setenv("DEFAULT_IDENTITY_CLIENT_ID", defaultClientID) @@ -1209,3 +1276,74 @@ func TestRefreshInMultipleRequests(t *testing.T) { } close(ch) } + +// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed +// and a bad access token is retrieved from the cache +func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { + setEnvVars(t, AppService) + localUrl := &url.URL{} + mockClient := mock.NewClient() + // Second response is a successful token response after retrying with claims + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + ) + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // Reset cache for clean test + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(SystemAssigned(), + WithHTTPClient(mockClient), + WithClientCapabilities([]string{"c1", "c2"})) + if err != nil { + t.Fatal(err) + } + + // Call AcquireToken which should trigger token revocation flow + result, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } + + // Call AcquireToken which should trigger token revocation flow + result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims")) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + localUrlQuerry := localUrl.Query() + + if localUrlQuerry.Get(apiVersionQueryParameterName) != appServiceAPIVersion { + t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName)) + } + if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if localUrlQuerry.Get("xms_cc") != "c1,c2" { + t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc")) + } + hash := sha256.Sum256([]byte(token)) + if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) { + t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh")) + } + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } +} From b15c36b079a2a2302e2194e5149a2c6786547e9d Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 3 Apr 2025 15:09:08 +0100 Subject: [PATCH 02/14] Removed one unused test --- apps/managedidentity/managedidentity_test.go | 65 -------------------- 1 file changed, 65 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index e553241a..d12d397e 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -698,71 +698,6 @@ func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) { } } -// TestAppServiceWithClientCapabilities tests the scenario when App Service includes client capabilities -func TestAppServiceWithClientCapabilities(t *testing.T) { - setEnvVars(t, AppService) - - testCases := []struct { - name string - expectError bool - expectedStatusCode int - expectedToken string - expectedCapabilities string - capabilities []string - }{ - { - name: "Token Request with Client Capabilities", - expectError: false, - expectedStatusCode: http.StatusOK, - expectedToken: token, - expectedCapabilities: "c1,c2", - capabilities: []string{"c1", "c2"}, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - mockClient := mock.NewClient() - localUrl := &url.URL{} - responseBody, err := getSuccessfulResponse(resource, false) - if err != nil { - t.Fatalf(errorFormingJsonResponse, err.Error()) - } - mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), - mock.WithBody(responseBody), - mock.WithCallback(func(r *http.Request) { - localUrl = r.URL - })) - client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithClientCapabilities(testCase.capabilities)) - if err != nil { - t.Fatal(err) - } - - result, err := client.AcquireToken(context.Background(), resource) - if testCase.expectError { - if err == nil { - t.Fatal("Expected error but got nil") - } - var callErr errors.CallErr - if !errors.As(err, &callErr) { - t.Fatalf("Expected error of type CallErr, got %T", err) - } - } else { - if err != nil { - t.Fatalf("AcquireToken failed: %v", err) - } - if result.AccessToken != testCase.expectedToken { - t.Fatalf("Expected access token %q, got %q", testCase.expectedToken, result.AccessToken) - } - localUrlQuerry := localUrl.Query() - if localUrlQuerry.Get("xms_cc") != testCase.expectedCapabilities { - t.Fatalf("Expected client capabilities %q, got %q", testCase.expectedCapabilities, localUrlQuerry.Get("xms_cc")) - } - } - }) - } -} - func TestAzureMLAcquireTokenReturnsTokenSuccess(t *testing.T) { defaultClientID := "A" t.Setenv("DEFAULT_IDENTITY_CLIENT_ID", defaultClientID) From e3d08ad81fd10c47fc2b6f93233f5c560c4e8fb9 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 3 Apr 2025 15:41:42 +0100 Subject: [PATCH 03/14] Updated documentation --- apps/managedidentity/managedidentity.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 947dabf0..a30e16d0 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -202,7 +202,7 @@ func WithClaims(claims string) AcquireTokenOption { } // WithClientCapabilities sets the client capabilities to be used in the request. -// This is used to enable specific features or behaviors in the token request. +// For details see https://learn.microsoft.com/en-us/entra/identity/conditional-access/concept-continuous-access-evaluation // The capabilities are passed as a slice of strings, and empty strings are filtered out. func WithClientCapabilities(capabilities []string) ClientOption { return func(o *Client) { @@ -353,7 +353,7 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac ar, err := base.AuthResultFromStorage(stResp) if err == nil { if o.claims != "" { - // When the claims are set, we need to passon bad/old token + // When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric) return c.getToken(ctx, resource, ar.AccessToken) } else { if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { From 4742a02e9b82bb1e0fe7d22e52236f1fe7286778 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 4 Apr 2025 09:41:52 +0100 Subject: [PATCH 04/14] Update managedidentity.go --- apps/managedidentity/managedidentity.go | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index a30e16d0..abaeb060 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -202,7 +202,7 @@ func WithClaims(claims string) AcquireTokenOption { } // WithClientCapabilities sets the client capabilities to be used in the request. -// For details see https://learn.microsoft.com/en-us/entra/identity/conditional-access/concept-continuous-access-evaluation +// For details see https://learn.microsoft.com/entra/identity/conditional-access/concept-continuous-access-evaluation // The capabilities are passed as a slice of strings, and empty strings are filtered out. func WithClientCapabilities(capabilities []string) ClientOption { return func(o *Client) { @@ -355,21 +355,20 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac if o.claims != "" { // When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric) return c.getToken(ctx, resource, ar.AccessToken) - } else { - if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { - defer c.canRefresh.Store(false) - if tr, er := c.getToken(ctx, resource, o.claims); er == nil { - return tr, nil - } + } + if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { + defer c.canRefresh.Store(false) + if tr, er := c.getToken(ctx, resource, ""); er == nil { + return tr, nil } - ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - return ar, err } + ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err } return c.getToken(ctx, resource, "") } -func (c Client) getToken(ctx context.Context, resource string, badToken string) (AuthResult, error) { +func (c Client) getToken(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { switch c.source { case AzureArc: return c.acquireTokenForAzureArc(ctx, resource) @@ -380,7 +379,7 @@ func (c Client) getToken(ctx context.Context, resource string, badToken string) case DefaultToIMDS: return c.acquireTokenForIMDS(ctx, resource) case AppService: - return c.acquireTokenForAppService(ctx, resource, badToken) + return c.acquireTokenForAppService(ctx, resource, revokedToken) case ServiceFabric: return c.acquireTokenForServiceFabric(ctx, resource) default: @@ -388,8 +387,8 @@ func (c Client) getToken(ctx context.Context, resource string, badToken string) } } -func (c Client) acquireTokenForAppService(ctx context.Context, resource string, badToken string) (AuthResult, error) { - req, err := createAppServiceAuthRequest(ctx, c.miType, resource, badToken, c.clientCapabilities) +func (c Client) acquireTokenForAppService(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { + req, err := createAppServiceAuthRequest(ctx, c.miType, resource, revokedToken, c.clientCapabilities) if err != nil { return AuthResult{}, err } @@ -594,7 +593,7 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto return r, err } -func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, badToken string, cc []string) (*http.Request, error) { +func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc []string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { @@ -606,8 +605,8 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, ba q.Set("api-version", appServiceAPIVersion) q.Set("resource", resource) - if badToken != "" { - hash := sha256.Sum256([]byte(badToken)) + if revokedToken != "" { + hash := sha256.Sum256([]byte(revokedToken)) q.Set("token_sha256_to_refresh", hex.EncodeToString(hash[:])) } From 96554ad8e0668a165f35149299152a83bae3758f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 4 Apr 2025 10:37:58 +0100 Subject: [PATCH 05/14] Updated Example test --- apps/managedidentity/example_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 apps/managedidentity/example_test.go diff --git a/apps/managedidentity/example_test.go b/apps/managedidentity/example_test.go new file mode 100644 index 00000000..a9e6e08f --- /dev/null +++ b/apps/managedidentity/example_test.go @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package managedidentity_test + +import ( + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" +) + +func ExampleNew() { + // System assigned Client + miSystemassignedClient, err := mi.New(mi.SystemAssigned()) + if err != nil { + // TODO: Handle error + } + _ = miSystemassignedClient + + // User assigned Client + clientId := "ClientId" // TODO: replace with your Managed Identity Id + + miClientIdAssignedClient, err := mi.New(mi.UserAssignedClientID(clientId), mi.WithClientCapabilities([]string{"cp1"})) + if err != nil { + // TODO: Handle error + } + _ = miClientIdAssignedClient +} From 16f873b606ac65675487f05d99d0fb3776329fab Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 4 Apr 2025 15:21:18 +0100 Subject: [PATCH 06/14] Added conversion from token string to SHA256 hash --- apps/managedidentity/managedidentity.go | 12 +++++++++--- apps/managedidentity/managedidentity_test.go | 10 ++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index abaeb060..55fe145b 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -12,7 +12,7 @@ package managedidentity import ( "context" "crypto/sha256" - "encoding/hex" + "encoding/base64" "encoding/json" "fmt" "io" @@ -606,8 +606,7 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, re q.Set("resource", resource) if revokedToken != "" { - hash := sha256.Sum256([]byte(revokedToken)) - q.Set("token_sha256_to_refresh", hex.EncodeToString(hash[:])) + q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken)) } if len(cc) > 0 { @@ -629,6 +628,13 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, re return req, nil } +func convertTokenToSHA256HashString(revokedToken string) string { + hash := sha256.New() + hash.Write([]byte(revokedToken)) + hashBytes := hash.Sum(nil) + return base64.StdEncoding.EncodeToString(hashBytes) +} + func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { msiEndpoint, err := url.Parse(imdsDefaultEndpoint) if err != nil { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index d12d397e..beb0182b 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1282,3 +1282,13 @@ func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) } } + +func TestConvertTokenToSHA256HashString(t *testing.T) { + // Test with a valid token + token := "test_token" + expectedHash := "zAr5codUO2XaLH4UdkJgIYJsqxZvHgY+0BK4Vf+BllY=" + hash := convertTokenToSHA256HashString(token) + if hash != expectedHash { + t.Fatalf("expected %q, got %q", expectedHash, hash) + } +} From 20525281346a2fa573d51a4799449e52ad835646 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 7 Apr 2025 10:24:09 +0100 Subject: [PATCH 07/14] Updated the encoding --- apps/managedidentity/managedidentity.go | 4 ++-- apps/managedidentity/managedidentity_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 55fe145b..28998d31 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -12,7 +12,7 @@ package managedidentity import ( "context" "crypto/sha256" - "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "io" @@ -632,7 +632,7 @@ func convertTokenToSHA256HashString(revokedToken string) string { hash := sha256.New() hash.Write([]byte(revokedToken)) hashBytes := hash.Sum(nil) - return base64.StdEncoding.EncodeToString(hashBytes) + return hex.EncodeToString(hashBytes) } func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index beb0182b..c9db0971 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1286,7 +1286,7 @@ func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { func TestConvertTokenToSHA256HashString(t *testing.T) { // Test with a valid token token := "test_token" - expectedHash := "zAr5codUO2XaLH4UdkJgIYJsqxZvHgY+0BK4Vf+BllY=" + expectedHash := "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656" hash := convertTokenToSHA256HashString(token) if hash != expectedHash { t.Fatalf("expected %q, got %q", expectedHash, hash) From 6817fc2307d29e1bf49ca1830e87305e2c1ac07a Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 7 Apr 2025 10:54:28 +0100 Subject: [PATCH 08/14] Added couple of more tests --- apps/managedidentity/managedidentity_test.go | 31 +++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index c9db0971..2968ab9c 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -966,7 +966,7 @@ func TestAzureArcErrors(t *testing.T) { }, { name: "Invalid file path", - headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey), + headerValue: basicRealm + filepath.Join("path", "to", secretKey), expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"), }, { @@ -1284,11 +1284,28 @@ func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { } func TestConvertTokenToSHA256HashString(t *testing.T) { - // Test with a valid token - token := "test_token" - expectedHash := "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656" - hash := convertTokenToSHA256HashString(token) - if hash != expectedHash { - t.Fatalf("expected %q, got %q", expectedHash, hash) + tests := []struct { + token string + expectedHash string + }{ + { + token: "test_token", + expectedHash: "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656", + }, + { + token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", + expectedHash: "01588d5a948b6c4facd47866877491b42866b5c10a4d342cf168e994101d352a", + }, + { + token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", + expectedHash: "29c538690068a8ad1797a391bfe23e7fb817b601fc7b78288cb499ab8fd37947", + }, + } + + for _, test := range tests { + hash := convertTokenToSHA256HashString(test.token) + if hash != test.expectedHash { + t.Fatalf("for token %q, expected %q, got %q", test.token, test.expectedHash, hash) + } } } From faaa17ea266c3cc0a1279199727a8683490a444d Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 8 Apr 2025 16:00:48 +0100 Subject: [PATCH 09/14] Added support for Service Fabric --- apps/managedidentity/managedidentity.go | 6 +- apps/managedidentity/servicefabric.go | 11 +++- apps/managedidentity/servicefabric_test.go | 74 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 4 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 28998d31..5a370813 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -381,7 +381,7 @@ func (c Client) getToken(ctx context.Context, resource string, revokedToken stri case AppService: return c.acquireTokenForAppService(ctx, resource, revokedToken) case ServiceFabric: - return c.acquireTokenForServiceFabric(ctx, resource) + return c.acquireTokenForServiceFabric(ctx, resource, revokedToken) default: return AuthResult{}, fmt.Errorf("unsupported source %q", c.source) } @@ -435,8 +435,8 @@ func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (Au return authResultFromToken(c.authParams, tokenResponse) } -func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (AuthResult, error) { - req, err := createServiceFabricAuthRequest(ctx, resource) +func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { + req, err := createServiceFabricAuthRequest(ctx, resource, revokedToken, c.clientCapabilities) if err != nil { return AuthResult{}, err } diff --git a/apps/managedidentity/servicefabric.go b/apps/managedidentity/servicefabric.go index 535065e9..be740fc7 100644 --- a/apps/managedidentity/servicefabric.go +++ b/apps/managedidentity/servicefabric.go @@ -7,9 +7,10 @@ import ( "context" "net/http" "os" + "strings" ) -func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http.Request, error) { +func createServiceFabricAuthRequest(ctx context.Context, resource string, revokedToken string, cc []string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { @@ -20,6 +21,14 @@ func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http q := req.URL.Query() q.Set("api-version", serviceFabricAPIVersion) q.Set("resource", resource) + if revokedToken != "" { + q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken)) + } + + if len(cc) > 0 { + q.Set("xms_cc", strings.Join(cc, ",")) + } + req.URL.RawQuery = q.Encode() return req, nil } diff --git a/apps/managedidentity/servicefabric_test.go b/apps/managedidentity/servicefabric_test.go index 37c24b0e..028f36ad 100644 --- a/apps/managedidentity/servicefabric_test.go +++ b/apps/managedidentity/servicefabric_test.go @@ -5,6 +5,8 @@ package managedidentity import ( "context" + "crypto/sha256" + "encoding/hex" "net/http" "net/url" "strings" @@ -96,6 +98,78 @@ func TestServiceFabricAcquireTokenReturnsTokenSuccess(t *testing.T) { }) } } + +// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed +// and a bad access token is retrieved from the cache +func TestServiceFabricWithClaimsAndBadAccessToken(t *testing.T) { + setEnvVars(t, ServiceFabric) + localUrl := &url.URL{} + mockClient := mock.NewClient() + // Second response is a successful token response after retrying with claims + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + ) + mockClient.AppendResponse( + mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // Reset cache for clean test + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(SystemAssigned(), + WithHTTPClient(mockClient), + WithClientCapabilities([]string{"c1", "c2"})) + if err != nil { + t.Fatal(err) + } + + // Call AcquireToken which should trigger token revocation flow + result, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } + + // Call AcquireToken which should trigger token revocation flow + result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims")) + if err != nil { + t.Fatalf("AcquireToken failed: %v", err) + } + + localUrlQuerry := localUrl.Query() + + if localUrlQuerry.Get(apiVersionQueryParameterName) != serviceFabricAPIVersion { + t.Fatalf("api-version not on %s got %s", serviceFabricAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName)) + } + if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if localUrlQuerry.Get("xms_cc") != "c1,c2" { + t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc")) + } + hash := sha256.Sum256([]byte(token)) + if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) { + t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh")) + } + // Verify token was obtained successfully + if result.AccessToken != token { + t.Fatalf("Expected access token %q, got %q", token, result.AccessToken) + } +} + func TestServiceFabricErrors(t *testing.T) { setEnvVars(t, ServiceFabric) mockClient := mock.NewClient() From 35be0ab09cdebf07e15dd2c28dcd01d9c34ebf26 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 14 Apr 2025 10:53:30 +0100 Subject: [PATCH 10/14] Updated CC as strings --- apps/managedidentity/managedidentity.go | 8 ++++---- apps/managedidentity/servicefabric.go | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 5a370813..6a2eafeb 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -180,7 +180,7 @@ type Client struct { authParams authority.AuthParams retryPolicyEnabled bool canRefresh *atomic.Value - clientCapabilities []string + clientCapabilities string } type AcquireTokenOptions struct { @@ -212,7 +212,7 @@ func WithClientCapabilities(capabilities []string) ClientOption { filteredCapabilities = append(filteredCapabilities, cap) } } - o.clientCapabilities = filteredCapabilities + o.clientCapabilities = strings.Join(filteredCapabilities, ",") } } @@ -593,7 +593,7 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto return r, err } -func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc []string) (*http.Request, error) { +func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { @@ -610,7 +610,7 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, re } if len(cc) > 0 { - q.Set("xms_cc", strings.Join(cc, ",")) + q.Set("xms_cc", cc) } switch t := id.(type) { diff --git a/apps/managedidentity/servicefabric.go b/apps/managedidentity/servicefabric.go index be740fc7..2b0f7bb2 100644 --- a/apps/managedidentity/servicefabric.go +++ b/apps/managedidentity/servicefabric.go @@ -7,10 +7,9 @@ import ( "context" "net/http" "os" - "strings" ) -func createServiceFabricAuthRequest(ctx context.Context, resource string, revokedToken string, cc []string) (*http.Request, error) { +func createServiceFabricAuthRequest(ctx context.Context, resource string, revokedToken string, cc string) (*http.Request, error) { identityEndpoint := os.Getenv(identityEndpointEnvVar) req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) if err != nil { @@ -26,7 +25,7 @@ func createServiceFabricAuthRequest(ctx context.Context, resource string, revoke } if len(cc) > 0 { - q.Set("xms_cc", strings.Join(cc, ",")) + q.Set("xms_cc", cc) } req.URL.RawQuery = q.Encode() From cc71d57b27076ab954b0a65f717568c61580a47f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 14 Apr 2025 11:34:57 +0100 Subject: [PATCH 11/14] Update example_test.go --- apps/managedidentity/example_test.go | 56 +++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/apps/managedidentity/example_test.go b/apps/managedidentity/example_test.go index a9e6e08f..6bfff8aa 100644 --- a/apps/managedidentity/example_test.go +++ b/apps/managedidentity/example_test.go @@ -4,23 +4,61 @@ package managedidentity_test import ( + "fmt" + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" ) func ExampleNew() { - // System assigned Client - miSystemassignedClient, err := mi.New(mi.SystemAssigned()) + // =============================== + // System-Assigned Managed Identity + // =============================== + + // A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function). + // Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted. + // Use this when your app only needs one identity and doesn’t need to share it across services. + systemAssignedClient, err := mi.New(mi.SystemAssigned()) if err != nil { - // TODO: Handle error + fmt.Printf("failed to create client with system-assigned identity: %v", err) } - _ = miSystemassignedClient + _ = systemAssignedClient // Use this client to authenticate to Azure services (e.g., Key Vault, Storage, etc.) + + // Learn more: + // https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity + + // ============================= + // User-Assigned Managed Identity + // ============================= + + // A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. + // It's ideal when: + // - You need a consistent identity across services (e.g., multiple apps accessing the same Key Vault) + // - You want to control the lifecycle of the identity independently of the resource + // - You need fine-grained role assignments or separation of concerns - // User assigned Client - clientId := "ClientId" // TODO: replace with your Managed Identity Id + clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID - miClientIdAssignedClient, err := mi.New(mi.UserAssignedClientID(clientId), mi.WithClientCapabilities([]string{"cp1"})) + userAssignedClient, err := mi.New( + mi.UserAssignedClientID(clientID), + + // =========================================== + // Optional: Client Capabilities ("cp1", etc.) + // =========================================== + + // 'cp1' is a capability that enables specific client behaviors — for example, + // supporting Conditional Access policies that require additional client capabilities. + // This is mostly relevant in scenarios where the identity is used to access resources + // protected by policies like MFA or device compliance. + + // In most cases, you won't need to set this unless required by your Azure AD configuration. + mi.WithClientCapabilities([]string{"cp1"}), + ) if err != nil { - // TODO: Handle error + fmt.Printf("failed to create client with user-assigned identity: %v", err) } - _ = miClientIdAssignedClient + _ = userAssignedClient // Use this client for authentication when stable or shared identity is required + + // Learn more: + // - User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity + // - Client capabilities: https://learn.microsoft.com/azure/active-directory/develop/msal-client-capabilities } From 7d901e6bb6d08739f5c0c3a89f68eac3c97fd364 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 17 Apr 2025 17:13:15 +0100 Subject: [PATCH 12/14] Updated example --- apps/managedidentity/example_test.go | 65 +++++++++++++++------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/apps/managedidentity/example_test.go b/apps/managedidentity/example_test.go index 6bfff8aa..36c4a3fd 100644 --- a/apps/managedidentity/example_test.go +++ b/apps/managedidentity/example_test.go @@ -9,56 +9,59 @@ import ( mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" ) +// =============================== +// System-Assigned Managed Identity +// =============================== +// A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function). +// Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted. +// Use this when your app only needs one identity and doesn’t need to share it across services. +// Learn more: +// https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity +// ============================= +// User-Assigned Managed Identity +// ============================= +// A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. +// User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity + func ExampleNew() { - // =============================== - // System-Assigned Managed Identity - // =============================== - // A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function). - // Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted. - // Use this when your app only needs one identity and doesn’t need to share it across services. systemAssignedClient, err := mi.New(mi.SystemAssigned()) if err != nil { fmt.Printf("failed to create client with system-assigned identity: %v", err) } _ = systemAssignedClient // Use this client to authenticate to Azure services (e.g., Key Vault, Storage, etc.) - // Learn more: - // https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity - - // ============================= - // User-Assigned Managed Identity - // ============================= - - // A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. - // It's ideal when: - // - You need a consistent identity across services (e.g., multiple apps accessing the same Key Vault) - // - You want to control the lifecycle of the identity independently of the resource - // - You need fine-grained role assignments or separation of concerns - clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID userAssignedClient, err := mi.New( mi.UserAssignedClientID(clientID), + ) + if err != nil { + fmt.Printf("failed to create client with user-assigned identity: %v", err) + } + _ = userAssignedClient // Use this client for authentication when stable or shared identity is required +} - // =========================================== - // Optional: Client Capabilities ("cp1", etc.) - // =========================================== +// =========================================== +// Optional: Client Capabilities ("cp1", etc.) +// =========================================== +// 'cp1' is a capability that enables specific client behaviors — for example, +// supporting Conditional Access policies that require additional client capabilities. +// This is mostly relevant in scenarios where the identity is used to access resources +// protected by policies like MFA or device compliance. +// In most cases, you won't need to set this unless required by your Azure AD configuration. +// Learn more: +// - Client capabilities: https://learn.microsoft.com/en-us/entra/msal/python/advanced/client-capabilities +func ExampleWithClientCapabilities() { + clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID - // 'cp1' is a capability that enables specific client behaviors — for example, - // supporting Conditional Access policies that require additional client capabilities. - // This is mostly relevant in scenarios where the identity is used to access resources - // protected by policies like MFA or device compliance. + userAssignedClient, err := mi.New( + mi.UserAssignedClientID(clientID), - // In most cases, you won't need to set this unless required by your Azure AD configuration. mi.WithClientCapabilities([]string{"cp1"}), ) if err != nil { fmt.Printf("failed to create client with user-assigned identity: %v", err) } _ = userAssignedClient // Use this client for authentication when stable or shared identity is required - - // Learn more: - // - User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity - // - Client capabilities: https://learn.microsoft.com/azure/active-directory/develop/msal-client-capabilities } From 1c9fa261d218dfeb706925114a6c26261d6d533a Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 18 Apr 2025 11:22:38 +0100 Subject: [PATCH 13/14] Updated sample example --- apps/managedidentity/example_test.go | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/apps/managedidentity/example_test.go b/apps/managedidentity/example_test.go index 36c4a3fd..9bcbd338 100644 --- a/apps/managedidentity/example_test.go +++ b/apps/managedidentity/example_test.go @@ -9,28 +9,27 @@ import ( mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" ) -// =============================== -// System-Assigned Managed Identity -// =============================== +// This example demonstrates how to create the managed identity client with system assigned +// // A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function). // Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted. // Use this when your app only needs one identity and doesn’t need to share it across services. // Learn more: // https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity -// ============================= -// User-Assigned Managed Identity -// ============================= -// A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. -// User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity - -func ExampleNew() { - +func ExampleNew_systemAssigned() { systemAssignedClient, err := mi.New(mi.SystemAssigned()) if err != nil { fmt.Printf("failed to create client with system-assigned identity: %v", err) } _ = systemAssignedClient // Use this client to authenticate to Azure services (e.g., Key Vault, Storage, etc.) +} + +// This example demonstrates how to create the managed identity client with user assigned +// +// A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. +// User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity +func ExampleNew_userAssigned() { clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID userAssignedClient, err := mi.New( @@ -42,16 +41,15 @@ func ExampleNew() { _ = userAssignedClient // Use this client for authentication when stable or shared identity is required } -// =========================================== -// Optional: Client Capabilities ("cp1", etc.) -// =========================================== +// Client Capabilities ("cp1", etc.) // 'cp1' is a capability that enables specific client behaviors — for example, // supporting Conditional Access policies that require additional client capabilities. // This is mostly relevant in scenarios where the identity is used to access resources // protected by policies like MFA or device compliance. // In most cases, you won't need to set this unless required by your Azure AD configuration. +// // Learn more: -// - Client capabilities: https://learn.microsoft.com/en-us/entra/msal/python/advanced/client-capabilities +// Client capabilities: https://learn.microsoft.com/en-us/entra/msal/python/advanced/client-capabilities func ExampleWithClientCapabilities() { clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID From d9b33cdf0597eb5040292ebca6fef9ff53cd913f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 13 May 2025 13:15:38 +0100 Subject: [PATCH 14/14] Updated with respected to pr comments --- apps/managedidentity/example_test.go | 2 +- apps/managedidentity/managedidentity.go | 5 ++- apps/managedidentity/managedidentity_test.go | 41 +++++++++++++++----- apps/managedidentity/servicefabric_test.go | 2 +- 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/apps/managedidentity/example_test.go b/apps/managedidentity/example_test.go index 9bcbd338..46ee6256 100644 --- a/apps/managedidentity/example_test.go +++ b/apps/managedidentity/example_test.go @@ -49,7 +49,7 @@ func ExampleNew_userAssigned() { // In most cases, you won't need to set this unless required by your Azure AD configuration. // // Learn more: -// Client capabilities: https://learn.microsoft.com/en-us/entra/msal/python/advanced/client-capabilities +// Client capabilities: https://learn.microsoft.com/entra/msal/python/advanced/client-capabilities func ExampleWithClientCapabilities() { clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 6a2eafeb..49f6fddd 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -208,8 +208,9 @@ func WithClientCapabilities(capabilities []string) ClientOption { return func(o *Client) { var filteredCapabilities []string for _, cap := range capabilities { - if cap != "" { - filteredCapabilities = append(filteredCapabilities, cap) + trimmedCap := strings.TrimSpace(cap) + if trimmedCap != "" { + filteredCapabilities = append(filteredCapabilities, trimmedCap) } } o.clientCapabilities = strings.Join(filteredCapabilities, ",") diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 2968ab9c..c46e6244 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1212,6 +1212,29 @@ func TestRefreshInMultipleRequests(t *testing.T) { close(ch) } +func TestWithClientCapabilities_TrimsSpaces(t *testing.T) { + setEnvVars(t, AppService) + mockClient := mock.NewClient() + // Reset cache for clean test + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + capabilitiesWithSpaces := []string{" cp1", " cp2 ", " cp3 "} + expectedCapabilities := "cp1,cp2,cp3" + + client, err := New(SystemAssigned(), + WithHTTPClient(mockClient), + WithClientCapabilities(capabilitiesWithSpaces)) + if err != nil { + t.Fatal(err) + } + + if client.clientCapabilities != expectedCapabilities { + t.Errorf("WithClientCapabilities() did not trim spaces correctly, got: %s, want: %s", client.clientCapabilities, expectedCapabilities) + } +} + // TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed // and a bad access token is retrieved from the cache func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { @@ -1240,7 +1263,7 @@ func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { client, err := New(SystemAssigned(), WithHTTPClient(mockClient), - WithClientCapabilities([]string{"c1", "c2"})) + WithClientCapabilities([]string{"cp1", "cp2"})) if err != nil { t.Fatal(err) } @@ -1262,20 +1285,20 @@ func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) { t.Fatalf("AcquireToken failed: %v", err) } - localUrlQuerry := localUrl.Query() + localUrlQuery := localUrl.Query() - if localUrlQuerry.Get(apiVersionQueryParameterName) != appServiceAPIVersion { - t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName)) + if localUrlQuery.Get(apiVersionQueryParameterName) != appServiceAPIVersion { + t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuery.Get(apiVersionQueryParameterName)) } - if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + if r := localUrlQuery.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { t.Fatal("suffix /.default was not removed.") } - if localUrlQuerry.Get("xms_cc") != "c1,c2" { - t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc")) + if localUrlQuery.Get("xms_cc") != "cp1,cp2" { + t.Fatalf("Expected client capabilities %q, got %q", "cp1,cp2", localUrlQuery.Get("xms_cc")) } hash := sha256.Sum256([]byte(token)) - if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) { - t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh")) + if localUrlQuery.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) { + t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuery.Get("token_sha256_to_refresh")) } // Verify token was obtained successfully if result.AccessToken != token { diff --git a/apps/managedidentity/servicefabric_test.go b/apps/managedidentity/servicefabric_test.go index 028f36ad..f8063655 100644 --- a/apps/managedidentity/servicefabric_test.go +++ b/apps/managedidentity/servicefabric_test.go @@ -99,7 +99,7 @@ func TestServiceFabricAcquireTokenReturnsTokenSuccess(t *testing.T) { } } -// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed +// TestServiceFabricWithClaimsAndBadAccessToken tests the scenario where claims are passed // and a bad access token is retrieved from the cache func TestServiceFabricWithClaimsAndBadAccessToken(t *testing.T) { setEnvVars(t, ServiceFabric)