-
Notifications
You must be signed in to change notification settings - Fork 99
Added support for Token revocation support #567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c7b8d0a
b15c36b
e3d08ad
4742a02
96554ad
16f873b
2052528
6817fc2
faaa17e
35be0ab
cc71d57
7d901e6
1c9fa26
d9b33cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package managedidentity_test | ||
|
||
import ( | ||
"fmt" | ||
|
||
mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" | ||
) | ||
|
||
// This example demonstrates how to create the managed identity client with system assigned | ||
// | ||
// A system-assigned managed identity is enabled directly on an Azure resource (like a VM, App Service, or Function). | ||
// Azure automatically creates this identity and ties it to the lifecycle of the resource — it gets deleted when the resource is deleted. | ||
// Use this when your app only needs one identity and doesn’t need to share it across services. | ||
// Learn more: | ||
// https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#system-assigned-managed-identity | ||
func ExampleNew_systemAssigned() { | ||
systemAssignedClient, err := mi.New(mi.SystemAssigned()) | ||
if err != nil { | ||
fmt.Printf("failed to create client with system-assigned identity: %v", err) | ||
} | ||
_ = systemAssignedClient // Use this client to authenticate to Azure services (e.g., Key Vault, Storage, etc.) | ||
|
||
} | ||
|
||
// This example demonstrates how to create the managed identity client with user assigned | ||
// | ||
// A user-assigned managed identity is a standalone Azure resource that can be assigned to one or more Azure resources. | ||
// User-assigned identities: https://learn.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview#user-assigned-managed-identity | ||
func ExampleNew_userAssigned() { | ||
clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID | ||
|
||
userAssignedClient, err := mi.New( | ||
mi.UserAssignedClientID(clientID), | ||
) | ||
if err != nil { | ||
fmt.Printf("failed to create client with user-assigned identity: %v", err) | ||
} | ||
_ = userAssignedClient // Use this client for authentication when stable or shared identity is required | ||
} | ||
|
||
// Client Capabilities ("cp1", etc.) | ||
// 'cp1' is a capability that enables specific client behaviors — for example, | ||
// supporting Conditional Access policies that require additional client capabilities. | ||
// This is mostly relevant in scenarios where the identity is used to access resources | ||
// protected by policies like MFA or device compliance. | ||
// In most cases, you won't need to set this unless required by your Azure AD configuration. | ||
// | ||
// Learn more: | ||
// Client capabilities: https://learn.microsoft.com/entra/msal/python/advanced/client-capabilities | ||
func ExampleWithClientCapabilities() { | ||
clientID := "your-user-assigned-client-id" // TODO: Replace with actual managed identity client ID | ||
|
||
userAssignedClient, err := mi.New( | ||
mi.UserAssignedClientID(clientID), | ||
|
||
mi.WithClientCapabilities([]string{"cp1"}), | ||
) | ||
if err != nil { | ||
fmt.Printf("failed to create client with user-assigned identity: %v", err) | ||
} | ||
_ = userAssignedClient // Use this client for authentication when stable or shared identity is required | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,8 @@ package managedidentity | |
|
||
import ( | ||
"context" | ||
"crypto/sha256" | ||
"encoding/hex" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
|
@@ -82,7 +84,7 @@ const ( | |
tokenName = "Tokens" | ||
|
||
// App Service | ||
appServiceAPIVersion = "2019-08-01" | ||
appServiceAPIVersion = "2025-03-30" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this looks good and matches the app service new version. We just want to make sure not to merge this PR yet, as App Service rollout is still happening. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does it work in MSAL .net about guarding the release or the version blocking ? |
||
|
||
// AzureML | ||
azureMLAPIVersion = "2017-09-01" | ||
|
@@ -178,6 +180,7 @@ type Client struct { | |
authParams authority.AuthParams | ||
retryPolicyEnabled bool | ||
canRefresh *atomic.Value | ||
clientCapabilities string | ||
} | ||
|
||
type AcquireTokenOptions struct { | ||
|
@@ -192,14 +195,35 @@ type AcquireTokenOption func(o *AcquireTokenOptions) | |
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. | ||
func WithClaims(claims string) AcquireTokenOption { | ||
return func(o *AcquireTokenOptions) { | ||
o.claims = claims | ||
if claims != "" { | ||
o.claims = claims | ||
} | ||
} | ||
} | ||
|
||
// WithClientCapabilities sets the client capabilities to be used in the request. | ||
// For details see https://learn.microsoft.com/entra/identity/conditional-access/concept-continuous-access-evaluation | ||
// The capabilities are passed as a slice of strings, and empty strings are filtered out. | ||
4gust marked this conversation as resolved.
Show resolved
Hide resolved
|
||
func WithClientCapabilities(capabilities []string) ClientOption { | ||
4gust marked this conversation as resolved.
Show resolved
Hide resolved
4gust marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return func(o *Client) { | ||
var filteredCapabilities []string | ||
for _, cap := range capabilities { | ||
trimmedCap := strings.TrimSpace(cap) | ||
if trimmedCap != "" { | ||
filteredCapabilities = append(filteredCapabilities, trimmedCap) | ||
} | ||
} | ||
o.clientCapabilities = strings.Join(filteredCapabilities, ",") | ||
4gust marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
// WithHTTPClient allows for a custom HTTP client to be set. | ||
// if nil, the default HTTP client will be used. | ||
func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { | ||
return func(c *Client) { | ||
c.httpClient = httpClient | ||
if httpClient != nil { | ||
c.httpClient = httpClient | ||
} | ||
} | ||
} | ||
|
||
|
@@ -323,28 +347,29 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac | |
} | ||
c.authParams.Scopes = []string{resource} | ||
|
||
// ignore cached access tokens when given claims | ||
if o.claims == "" { | ||
stResp, err := cacheManager.Read(ctx, c.authParams) | ||
if err != nil { | ||
return AuthResult{}, err | ||
stResp, err := cacheManager.Read(ctx, c.authParams) | ||
if err != nil { | ||
return AuthResult{}, err | ||
} | ||
ar, err := base.AuthResultFromStorage(stResp) | ||
if err == nil { | ||
if o.claims != "" { | ||
// When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric) | ||
return c.getToken(ctx, resource, ar.AccessToken) | ||
} | ||
ar, err := base.AuthResultFromStorage(stResp) | ||
if err == nil { | ||
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { | ||
defer c.canRefresh.Store(false) | ||
if tr, er := c.getToken(ctx, resource); er == nil { | ||
return tr, nil | ||
} | ||
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) { | ||
defer c.canRefresh.Store(false) | ||
if tr, er := c.getToken(ctx, resource, ""); er == nil { | ||
return tr, nil | ||
} | ||
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) | ||
return ar, err | ||
} | ||
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) | ||
return ar, err | ||
} | ||
return c.getToken(ctx, resource) | ||
return c.getToken(ctx, resource, "") | ||
} | ||
|
||
func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) { | ||
func (c Client) getToken(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { | ||
switch c.source { | ||
case AzureArc: | ||
return c.acquireTokenForAzureArc(ctx, resource) | ||
|
@@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro | |
case DefaultToIMDS: | ||
return c.acquireTokenForIMDS(ctx, resource) | ||
case AppService: | ||
return c.acquireTokenForAppService(ctx, resource) | ||
return c.acquireTokenForAppService(ctx, resource, revokedToken) | ||
case ServiceFabric: | ||
return c.acquireTokenForServiceFabric(ctx, resource) | ||
return c.acquireTokenForServiceFabric(ctx, resource, revokedToken) | ||
default: | ||
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source) | ||
} | ||
} | ||
|
||
func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) { | ||
req, err := createAppServiceAuthRequest(ctx, c.miType, resource) | ||
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { | ||
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, revokedToken, c.clientCapabilities) | ||
if err != nil { | ||
return AuthResult{}, err | ||
} | ||
|
@@ -411,8 +436,8 @@ func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (Au | |
return authResultFromToken(c.authParams, tokenResponse) | ||
} | ||
|
||
func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (AuthResult, error) { | ||
req, err := createServiceFabricAuthRequest(ctx, resource) | ||
func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string, revokedToken string) (AuthResult, error) { | ||
req, err := createServiceFabricAuthRequest(ctx, resource, revokedToken, c.clientCapabilities) | ||
if err != nil { | ||
return AuthResult{}, err | ||
} | ||
|
@@ -569,16 +594,26 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto | |
return r, err | ||
} | ||
|
||
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { | ||
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc string) (*http.Request, error) { | ||
identityEndpoint := os.Getenv(identityEndpointEnvVar) | ||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar)) | ||
|
||
q := req.URL.Query() | ||
q.Set("api-version", appServiceAPIVersion) | ||
q.Set("resource", resource) | ||
|
||
if revokedToken != "" { | ||
q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken)) | ||
} | ||
|
||
if len(cc) > 0 { | ||
q.Set("xms_cc", cc) | ||
} | ||
|
||
switch t := id.(type) { | ||
case UserAssignedClientID: | ||
q.Set(miQueryParameterClientId, string(t)) | ||
|
@@ -594,6 +629,13 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (* | |
return req, nil | ||
} | ||
|
||
func convertTokenToSHA256HashString(revokedToken string) string { | ||
hash := sha256.New() | ||
hash.Write([]byte(revokedToken)) | ||
hashBytes := hash.Sum(nil) | ||
return hex.EncodeToString(hashBytes) | ||
} | ||
|
||
func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { | ||
msiEndpoint, err := url.Parse(imdsDefaultEndpoint) | ||
if err != nil { | ||
|
Uh oh!
There was an error while loading. Please reload this page.