Skip to content

Commit c7b8d0a

Browse files
committed
Added support for Token revocation support
1 parent 4900473 commit c7b8d0a

File tree

2 files changed

+192
-18
lines changed

2 files changed

+192
-18
lines changed

apps/managedidentity/managedidentity.go

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ package managedidentity
1111

1212
import (
1313
"context"
14+
"crypto/sha256"
15+
"encoding/hex"
1416
"encoding/json"
1517
"fmt"
1618
"io"
@@ -82,7 +84,7 @@ const (
8284
tokenName = "Tokens"
8385

8486
// App Service
85-
appServiceAPIVersion = "2019-08-01"
87+
appServiceAPIVersion = "2025-03-30"
8688

8789
// AzureML
8890
azureMLAPIVersion = "2017-09-01"
@@ -178,6 +180,7 @@ type Client struct {
178180
authParams authority.AuthParams
179181
retryPolicyEnabled bool
180182
canRefresh *atomic.Value
183+
clientCapabilities []string
181184
}
182185

183186
type AcquireTokenOptions struct {
@@ -192,14 +195,34 @@ type AcquireTokenOption func(o *AcquireTokenOptions)
192195
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
193196
func WithClaims(claims string) AcquireTokenOption {
194197
return func(o *AcquireTokenOptions) {
195-
o.claims = claims
198+
if claims != "" {
199+
o.claims = claims
200+
}
201+
}
202+
}
203+
204+
// WithClientCapabilities sets the client capabilities to be used in the request.
205+
// This is used to enable specific features or behaviors in the token request.
206+
// The capabilities are passed as a slice of strings, and empty strings are filtered out.
207+
func WithClientCapabilities(capabilities []string) ClientOption {
208+
return func(o *Client) {
209+
var filteredCapabilities []string
210+
for _, cap := range capabilities {
211+
if cap != "" {
212+
filteredCapabilities = append(filteredCapabilities, cap)
213+
}
214+
}
215+
o.clientCapabilities = filteredCapabilities
196216
}
197217
}
198218

199219
// WithHTTPClient allows for a custom HTTP client to be set.
220+
// if nil, the default HTTP client will be used.
200221
func WithHTTPClient(httpClient ops.HTTPClient) ClientOption {
201222
return func(c *Client) {
202-
c.httpClient = httpClient
223+
if httpClient != nil {
224+
c.httpClient = httpClient
225+
}
203226
}
204227
}
205228

@@ -323,28 +346,30 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
323346
}
324347
c.authParams.Scopes = []string{resource}
325348

326-
// ignore cached access tokens when given claims
327-
if o.claims == "" {
328-
stResp, err := cacheManager.Read(ctx, c.authParams)
329-
if err != nil {
330-
return AuthResult{}, err
331-
}
332-
ar, err := base.AuthResultFromStorage(stResp)
333-
if err == nil {
349+
stResp, err := cacheManager.Read(ctx, c.authParams)
350+
if err != nil {
351+
return AuthResult{}, err
352+
}
353+
ar, err := base.AuthResultFromStorage(stResp)
354+
if err == nil {
355+
if o.claims != "" {
356+
// When the claims are set, we need to passon bad/old token
357+
return c.getToken(ctx, resource, ar.AccessToken)
358+
} else {
334359
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
335360
defer c.canRefresh.Store(false)
336-
if tr, er := c.getToken(ctx, resource); er == nil {
361+
if tr, er := c.getToken(ctx, resource, o.claims); er == nil {
337362
return tr, nil
338363
}
339364
}
340365
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
341366
return ar, err
342367
}
343368
}
344-
return c.getToken(ctx, resource)
369+
return c.getToken(ctx, resource, "")
345370
}
346371

347-
func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) {
372+
func (c Client) getToken(ctx context.Context, resource string, badToken string) (AuthResult, error) {
348373
switch c.source {
349374
case AzureArc:
350375
return c.acquireTokenForAzureArc(ctx, resource)
@@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro
355380
case DefaultToIMDS:
356381
return c.acquireTokenForIMDS(ctx, resource)
357382
case AppService:
358-
return c.acquireTokenForAppService(ctx, resource)
383+
return c.acquireTokenForAppService(ctx, resource, badToken)
359384
case ServiceFabric:
360385
return c.acquireTokenForServiceFabric(ctx, resource)
361386
default:
362387
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
363388
}
364389
}
365390

366-
func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) {
367-
req, err := createAppServiceAuthRequest(ctx, c.miType, resource)
391+
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, badToken string) (AuthResult, error) {
392+
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, badToken, c.clientCapabilities)
368393
if err != nil {
369394
return AuthResult{}, err
370395
}
@@ -569,16 +594,27 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
569594
return r, err
570595
}
571596

572-
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
597+
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, badToken string, cc []string) (*http.Request, error) {
573598
identityEndpoint := os.Getenv(identityEndpointEnvVar)
574599
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
575600
if err != nil {
576601
return nil, err
577602
}
578603
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar))
604+
579605
q := req.URL.Query()
580606
q.Set("api-version", appServiceAPIVersion)
581607
q.Set("resource", resource)
608+
609+
if badToken != "" {
610+
hash := sha256.Sum256([]byte(badToken))
611+
q.Set("token_sha256_to_refresh", hex.EncodeToString(hash[:]))
612+
}
613+
614+
if len(cc) > 0 {
615+
q.Set("xms_cc", strings.Join(cc, ","))
616+
}
617+
582618
switch t := id.(type) {
583619
case UserAssignedClientID:
584620
q.Set(miQueryParameterClientId, string(t))

apps/managedidentity/managedidentity_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package managedidentity
55
import (
66
"bytes"
77
"context"
8+
"crypto/sha256"
9+
"encoding/hex"
810
"encoding/json"
911
"fmt"
1012
"io"
@@ -696,6 +698,71 @@ func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) {
696698
}
697699
}
698700

701+
// TestAppServiceWithClientCapabilities tests the scenario when App Service includes client capabilities
702+
func TestAppServiceWithClientCapabilities(t *testing.T) {
703+
setEnvVars(t, AppService)
704+
705+
testCases := []struct {
706+
name string
707+
expectError bool
708+
expectedStatusCode int
709+
expectedToken string
710+
expectedCapabilities string
711+
capabilities []string
712+
}{
713+
{
714+
name: "Token Request with Client Capabilities",
715+
expectError: false,
716+
expectedStatusCode: http.StatusOK,
717+
expectedToken: token,
718+
expectedCapabilities: "c1,c2",
719+
capabilities: []string{"c1", "c2"},
720+
},
721+
}
722+
723+
for _, testCase := range testCases {
724+
t.Run(testCase.name, func(t *testing.T) {
725+
mockClient := mock.NewClient()
726+
localUrl := &url.URL{}
727+
responseBody, err := getSuccessfulResponse(resource, false)
728+
if err != nil {
729+
t.Fatalf(errorFormingJsonResponse, err.Error())
730+
}
731+
mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK),
732+
mock.WithBody(responseBody),
733+
mock.WithCallback(func(r *http.Request) {
734+
localUrl = r.URL
735+
}))
736+
client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithClientCapabilities(testCase.capabilities))
737+
if err != nil {
738+
t.Fatal(err)
739+
}
740+
741+
result, err := client.AcquireToken(context.Background(), resource)
742+
if testCase.expectError {
743+
if err == nil {
744+
t.Fatal("Expected error but got nil")
745+
}
746+
var callErr errors.CallErr
747+
if !errors.As(err, &callErr) {
748+
t.Fatalf("Expected error of type CallErr, got %T", err)
749+
}
750+
} else {
751+
if err != nil {
752+
t.Fatalf("AcquireToken failed: %v", err)
753+
}
754+
if result.AccessToken != testCase.expectedToken {
755+
t.Fatalf("Expected access token %q, got %q", testCase.expectedToken, result.AccessToken)
756+
}
757+
localUrlQuerry := localUrl.Query()
758+
if localUrlQuerry.Get("xms_cc") != testCase.expectedCapabilities {
759+
t.Fatalf("Expected client capabilities %q, got %q", testCase.expectedCapabilities, localUrlQuerry.Get("xms_cc"))
760+
}
761+
}
762+
})
763+
}
764+
}
765+
699766
func TestAzureMLAcquireTokenReturnsTokenSuccess(t *testing.T) {
700767
defaultClientID := "A"
701768
t.Setenv("DEFAULT_IDENTITY_CLIENT_ID", defaultClientID)
@@ -1209,3 +1276,74 @@ func TestRefreshInMultipleRequests(t *testing.T) {
12091276
}
12101277
close(ch)
12111278
}
1279+
1280+
// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed
1281+
// and a bad access token is retrieved from the cache
1282+
func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) {
1283+
setEnvVars(t, AppService)
1284+
localUrl := &url.URL{}
1285+
mockClient := mock.NewClient()
1286+
// Second response is a successful token response after retrying with claims
1287+
responseBody, err := getSuccessfulResponse(resource, false)
1288+
if err != nil {
1289+
t.Fatalf(errorFormingJsonResponse, err.Error())
1290+
}
1291+
mockClient.AppendResponse(
1292+
mock.WithHTTPStatusCode(http.StatusOK),
1293+
mock.WithBody(responseBody),
1294+
)
1295+
mockClient.AppendResponse(
1296+
mock.WithHTTPStatusCode(http.StatusOK),
1297+
mock.WithBody(responseBody),
1298+
mock.WithCallback(func(r *http.Request) {
1299+
localUrl = r.URL
1300+
}))
1301+
// Reset cache for clean test
1302+
before := cacheManager
1303+
defer func() { cacheManager = before }()
1304+
cacheManager = storage.New(nil)
1305+
1306+
client, err := New(SystemAssigned(),
1307+
WithHTTPClient(mockClient),
1308+
WithClientCapabilities([]string{"c1", "c2"}))
1309+
if err != nil {
1310+
t.Fatal(err)
1311+
}
1312+
1313+
// Call AcquireToken which should trigger token revocation flow
1314+
result, err := client.AcquireToken(context.Background(), resource)
1315+
if err != nil {
1316+
t.Fatalf("AcquireToken failed: %v", err)
1317+
}
1318+
1319+
// Verify token was obtained successfully
1320+
if result.AccessToken != token {
1321+
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
1322+
}
1323+
1324+
// Call AcquireToken which should trigger token revocation flow
1325+
result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims"))
1326+
if err != nil {
1327+
t.Fatalf("AcquireToken failed: %v", err)
1328+
}
1329+
1330+
localUrlQuerry := localUrl.Query()
1331+
1332+
if localUrlQuerry.Get(apiVersionQueryParameterName) != appServiceAPIVersion {
1333+
t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName))
1334+
}
1335+
if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") {
1336+
t.Fatal("suffix /.default was not removed.")
1337+
}
1338+
if localUrlQuerry.Get("xms_cc") != "c1,c2" {
1339+
t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc"))
1340+
}
1341+
hash := sha256.Sum256([]byte(token))
1342+
if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) {
1343+
t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh"))
1344+
}
1345+
// Verify token was obtained successfully
1346+
if result.AccessToken != token {
1347+
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
1348+
}
1349+
}

0 commit comments

Comments
 (0)