Skip to content
65 changes: 65 additions & 0 deletions apps/managedidentity/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package managedidentity_test

import (
"fmt"

mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
)

// This example demonstrates how to create the managed identity client with system assigned
//
// A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function).
// Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted.
// Use this when your app only needs one identity and doesn’t need to share it across services.
// Learn more:
// https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity
func ExampleNew_systemAssigned() {
systemAssignedClient, err := mi.New(mi.SystemAssigned())
if err != nil {
fmt.Printf("failed to create client with system-assigned identity: %v", err)
}
_ = systemAssignedClient // Use this client to authenticate to Azure services (e.g., Key Vault, Storage, etc.)

}

// This example demonstrates how to create the managed identity client with user assigned
//
// A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources.
// User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity
func ExampleNew_userAssigned() {
clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID

userAssignedClient, err := mi.New(
mi.UserAssignedClientID(clientID),
)
if err != nil {
fmt.Printf("failed to create client with user-assigned identity: %v", err)
}
_ = userAssignedClient // Use this client for authentication when stable or shared identity is required
}

// Client Capabilities ("cp1", etc.)
// 'cp1' is a capability that enables specific client behaviors — for example,
// supporting Conditional Access policies that require additional client capabilities.
// This is mostly relevant in scenarios where the identity is used to access resources
// protected by policies like MFA or device compliance.
// In most cases, you won't need to set this unless required by your Azure AD configuration.
//
// Learn more:
// Client capabilities: https://learn.microsoft.com/entra/msal/python/advanced/client-capabilities
func ExampleWithClientCapabilities() {
clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID

userAssignedClient, err := mi.New(
mi.UserAssignedClientID(clientID),

mi.WithClientCapabilities([]string{"cp1"}),
)
if err != nil {
fmt.Printf("failed to create client with user-assigned identity: %v", err)
}
_ = userAssignedClient // Use this client for authentication when stable or shared identity is required
}
94 changes: 68 additions & 26 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package managedidentity

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -82,7 +84,7 @@ const (
tokenName = "Tokens"

// App Service
appServiceAPIVersion = "2019-08-01"
appServiceAPIVersion = "2025-03-30"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks good and matches the app service new version. We just want to make sure not to merge this PR yet, as App Service rollout is still happening.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work in MSAL .net about guarding the release or the version blocking ?


// AzureML
azureMLAPIVersion = "2017-09-01"
Expand Down Expand Up @@ -178,6 +180,7 @@ type Client struct {
authParams authority.AuthParams
retryPolicyEnabled bool
canRefresh *atomic.Value
clientCapabilities string
}

type AcquireTokenOptions struct {
Expand All @@ -192,14 +195,35 @@ type AcquireTokenOption func(o *AcquireTokenOptions)
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
func WithClaims(claims string) AcquireTokenOption {
return func(o *AcquireTokenOptions) {
o.claims = claims
if claims != "" {
o.claims = claims
}
}
}

// WithClientCapabilities sets the client capabilities to be used in the request.
// For details see https://learn.microsoft.com/entra/identity/conditional-access/concept-continuous-access-evaluation
// The capabilities are passed as a slice of strings, and empty strings are filtered out.
func WithClientCapabilities(capabilities []string) ClientOption {
return func(o *Client) {
var filteredCapabilities []string
for _, cap := range capabilities {
trimmedCap := strings.TrimSpace(cap)
if trimmedCap != "" {
filteredCapabilities = append(filteredCapabilities, trimmedCap)
}
}
o.clientCapabilities = strings.Join(filteredCapabilities, ",")
}
}

// WithHTTPClient allows for a custom HTTP client to be set.
// if nil, the default HTTP client will be used.
func WithHTTPClient(httpClient ops.HTTPClient) ClientOption {
return func(c *Client) {
c.httpClient = httpClient
if httpClient != nil {
c.httpClient = httpClient
}
}
}

Expand Down Expand Up @@ -323,28 +347,29 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
}
c.authParams.Scopes = []string{resource}

// ignore cached access tokens when given claims
if o.claims == "" {
stResp, err := cacheManager.Read(ctx, c.authParams)
if err != nil {
return AuthResult{}, err
stResp, err := cacheManager.Read(ctx, c.authParams)
if err != nil {
return AuthResult{}, err
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
if o.claims != "" {
// When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric)
return c.getToken(ctx, resource, ar.AccessToken)
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
defer c.canRefresh.Store(false)
if tr, er := c.getToken(ctx, resource); er == nil {
return tr, nil
}
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
defer c.canRefresh.Store(false)
if tr, er := c.getToken(ctx, resource, ""); er == nil {
return tr, nil
}
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
return c.getToken(ctx, resource)
return c.getToken(ctx, resource, "")
}

func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) {
func (c Client) getToken(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
switch c.source {
case AzureArc:
return c.acquireTokenForAzureArc(ctx, resource)
Expand All @@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro
case DefaultToIMDS:
return c.acquireTokenForIMDS(ctx, resource)
case AppService:
return c.acquireTokenForAppService(ctx, resource)
return c.acquireTokenForAppService(ctx, resource, revokedToken)
case ServiceFabric:
return c.acquireTokenForServiceFabric(ctx, resource)
return c.acquireTokenForServiceFabric(ctx, resource, revokedToken)
default:
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
}
}

func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource)
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, revokedToken, c.clientCapabilities)
if err != nil {
return AuthResult{}, err
}
Expand Down Expand Up @@ -411,8 +436,8 @@ func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (Au
return authResultFromToken(c.authParams, tokenResponse)
}

func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (AuthResult, error) {
req, err := createServiceFabricAuthRequest(ctx, resource)
func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
req, err := createServiceFabricAuthRequest(ctx, resource, revokedToken, c.clientCapabilities)
if err != nil {
return AuthResult{}, err
}
Expand Down Expand Up @@ -569,16 +594,26 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
return r, err
}

func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc string) (*http.Request, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar))

q := req.URL.Query()
q.Set("api-version", appServiceAPIVersion)
q.Set("resource", resource)

if revokedToken != "" {
q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken))
}

if len(cc) > 0 {
q.Set("xms_cc", cc)
}

switch t := id.(type) {
case UserAssignedClientID:
q.Set(miQueryParameterClientId, string(t))
Expand All @@ -594,6 +629,13 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*
return req, nil
}

func convertTokenToSHA256HashString(revokedToken string) string {
hash := sha256.New()
hash.Write([]byte(revokedToken))
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}

func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
msiEndpoint, err := url.Parse(imdsDefaultEndpoint)
if err != nil {
Expand Down
125 changes: 124 additions & 1 deletion apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package managedidentity
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -964,7 +966,7 @@ func TestAzureArcErrors(t *testing.T) {
},
{
name: "Invalid file path",
headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey),
headerValue: basicRealm + filepath.Join("path", "to", secretKey),
expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"),
},
{
Expand Down Expand Up @@ -1209,3 +1211,124 @@ func TestRefreshInMultipleRequests(t *testing.T) {
}
close(ch)
}

func TestWithClientCapabilities_TrimsSpaces(t *testing.T) {
setEnvVars(t, AppService)
mockClient := mock.NewClient()
// Reset cache for clean test
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

capabilitiesWithSpaces := []string{" cp1", " cp2 ", " cp3 "}
expectedCapabilities := "cp1,cp2,cp3"

client, err := New(SystemAssigned(),
WithHTTPClient(mockClient),
WithClientCapabilities(capabilitiesWithSpaces))
if err != nil {
t.Fatal(err)
}

if client.clientCapabilities != expectedCapabilities {
t.Errorf("WithClientCapabilities() did not trim spaces correctly, got: %s, want: %s", client.clientCapabilities, expectedCapabilities)
}
}

// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed
// and a bad access token is retrieved from the cache
func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) {
setEnvVars(t, AppService)
localUrl := &url.URL{}
mockClient := mock.NewClient()
// Second response is a successful token response after retrying with claims
responseBody, err := getSuccessfulResponse(resource, false)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
)
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
mock.WithCallback(func(r *http.Request) {
localUrl = r.URL
}))
// Reset cache for clean test
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

client, err := New(SystemAssigned(),
WithHTTPClient(mockClient),
WithClientCapabilities([]string{"cp1", "cp2"}))
if err != nil {
t.Fatal(err)
}

// Call AcquireToken which should trigger token revocation flow
result, err := client.AcquireToken(context.Background(), resource)
if err != nil {
t.Fatalf("AcquireToken failed: %v", err)
}

// Verify token was obtained successfully
if result.AccessToken != token {
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
}

// Call AcquireToken which should trigger token revocation flow
result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims"))
if err != nil {
t.Fatalf("AcquireToken failed: %v", err)
}

localUrlQuery := localUrl.Query()

if localUrlQuery.Get(apiVersionQueryParameterName) != appServiceAPIVersion {
t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuery.Get(apiVersionQueryParameterName))
}
if r := localUrlQuery.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") {
t.Fatal("suffix /.default was not removed.")
}
if localUrlQuery.Get("xms_cc") != "cp1,cp2" {
t.Fatalf("Expected client capabilities %q, got %q", "cp1,cp2", localUrlQuery.Get("xms_cc"))
}
hash := sha256.Sum256([]byte(token))
if localUrlQuery.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) {
t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuery.Get("token_sha256_to_refresh"))
}
// Verify token was obtained successfully
if result.AccessToken != token {
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
}
}

func TestConvertTokenToSHA256HashString(t *testing.T) {
tests := []struct {
token string
expectedHash string
}{
{
token: "test_token",
expectedHash: "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656",
},
{
token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~",
expectedHash: "01588d5a948b6c4facd47866877491b42866b5c10a4d342cf168e994101d352a",
},
{
token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~",
expectedHash: "29c538690068a8ad1797a391bfe23e7fb817b601fc7b78288cb499ab8fd37947",
},
}

for _, test := range tests {
hash := convertTokenToSHA256HashString(test.token)
if hash != test.expectedHash {
t.Fatalf("for token %q, expected %q, got %q", test.token, test.expectedHash, hash)
}
}
}
Loading