Skip to content

Commit 4742a02

Browse files
committed
Update managedidentity.go
1 parent e3d08ad commit 4742a02

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

apps/managedidentity/managedidentity.go

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func WithClaims(claims string) AcquireTokenOption {
202202
}
203203

204204
// WithClientCapabilities sets the client capabilities to be used in the request.
205-
// For details see https://learn.microsoft.com/en-us/entra/identity/conditional-access/concept-continuous-access-evaluation
205+
// For details see https://learn.microsoft.com/entra/identity/conditional-access/concept-continuous-access-evaluation
206206
// The capabilities are passed as a slice of strings, and empty strings are filtered out.
207207
func WithClientCapabilities(capabilities []string) ClientOption {
208208
return func(o *Client) {
@@ -355,21 +355,20 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
355355
if o.claims != "" {
356356
// When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric)
357357
return c.getToken(ctx, resource, ar.AccessToken)
358-
} else {
359-
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
360-
defer c.canRefresh.Store(false)
361-
if tr, er := c.getToken(ctx, resource, o.claims); er == nil {
362-
return tr, nil
363-
}
358+
}
359+
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
360+
defer c.canRefresh.Store(false)
361+
if tr, er := c.getToken(ctx, resource, ""); er == nil {
362+
return tr, nil
364363
}
365-
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
366-
return ar, err
367364
}
365+
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
366+
return ar, err
368367
}
369368
return c.getToken(ctx, resource, "")
370369
}
371370

372-
func (c Client) getToken(ctx context.Context, resource string, badToken string) (AuthResult, error) {
371+
func (c Client) getToken(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
373372
switch c.source {
374373
case AzureArc:
375374
return c.acquireTokenForAzureArc(ctx, resource)
@@ -380,16 +379,16 @@ func (c Client) getToken(ctx context.Context, resource string, badToken string)
380379
case DefaultToIMDS:
381380
return c.acquireTokenForIMDS(ctx, resource)
382381
case AppService:
383-
return c.acquireTokenForAppService(ctx, resource, badToken)
382+
return c.acquireTokenForAppService(ctx, resource, revokedToken)
384383
case ServiceFabric:
385384
return c.acquireTokenForServiceFabric(ctx, resource)
386385
default:
387386
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
388387
}
389388
}
390389

391-
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, badToken string) (AuthResult, error) {
392-
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, badToken, c.clientCapabilities)
390+
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
391+
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, revokedToken, c.clientCapabilities)
393392
if err != nil {
394393
return AuthResult{}, err
395394
}
@@ -594,7 +593,7 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
594593
return r, err
595594
}
596595

597-
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, badToken string, cc []string) (*http.Request, error) {
596+
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc []string) (*http.Request, error) {
598597
identityEndpoint := os.Getenv(identityEndpointEnvVar)
599598
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
600599
if err != nil {
@@ -606,8 +605,8 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, ba
606605
q.Set("api-version", appServiceAPIVersion)
607606
q.Set("resource", resource)
608607

609-
if badToken != "" {
610-
hash := sha256.Sum256([]byte(badToken))
608+
if revokedToken != "" {
609+
hash := sha256.Sum256([]byte(revokedToken))
611610
q.Set("token_sha256_to_refresh", hex.EncodeToString(hash[:]))
612611
}
613612

0 commit comments

Comments
 (0)