@@ -11,6 +11,8 @@ package managedidentity
1111
1212import (
1313 "context"
14+ "crypto/sha256"
15+ "encoding/hex"
1416 "encoding/json"
1517 "fmt"
1618 "io"
@@ -82,7 +84,7 @@ const (
8284 tokenName = "Tokens"
8385
8486 // App Service
85- appServiceAPIVersion = "2019-08-01 "
87+ appServiceAPIVersion = "2025-03-30 "
8688
8789 // AzureML
8890 azureMLAPIVersion = "2017-09-01"
@@ -178,6 +180,7 @@ type Client struct {
178180 authParams authority.AuthParams
179181 retryPolicyEnabled bool
180182 canRefresh * atomic.Value
183+ clientCapabilities []string
181184}
182185
183186type AcquireTokenOptions struct {
@@ -192,14 +195,34 @@ type AcquireTokenOption func(o *AcquireTokenOptions)
192195// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
193196func WithClaims (claims string ) AcquireTokenOption {
194197 return func (o * AcquireTokenOptions ) {
195- o .claims = claims
198+ if claims != "" {
199+ o .claims = claims
200+ }
201+ }
202+ }
203+
204+ // WithClientCapabilities sets the client capabilities to be used in the request.
205+ // This is used to enable specific features or behaviors in the token request.
206+ // The capabilities are passed as a slice of strings, and empty strings are filtered out.
207+ func WithClientCapabilities (capabilities []string ) ClientOption {
208+ return func (o * Client ) {
209+ var filteredCapabilities []string
210+ for _ , cap := range capabilities {
211+ if cap != "" {
212+ filteredCapabilities = append (filteredCapabilities , cap )
213+ }
214+ }
215+ o .clientCapabilities = filteredCapabilities
196216 }
197217}
198218
199219// WithHTTPClient allows for a custom HTTP client to be set.
220+ // if nil, the default HTTP client will be used.
200221func WithHTTPClient (httpClient ops.HTTPClient ) ClientOption {
201222 return func (c * Client ) {
202- c .httpClient = httpClient
223+ if httpClient != nil {
224+ c .httpClient = httpClient
225+ }
203226 }
204227}
205228
@@ -323,28 +346,30 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
323346 }
324347 c .authParams .Scopes = []string {resource }
325348
326- // ignore cached access tokens when given claims
327- if o .claims == "" {
328- stResp , err := cacheManager .Read (ctx , c .authParams )
329- if err != nil {
330- return AuthResult {}, err
331- }
332- ar , err := base .AuthResultFromStorage (stResp )
333- if err == nil {
349+ stResp , err := cacheManager .Read (ctx , c .authParams )
350+ if err != nil {
351+ return AuthResult {}, err
352+ }
353+ ar , err := base .AuthResultFromStorage (stResp )
354+ if err == nil {
355+ if o .claims != "" {
356+ // When the claims are set, we need to passon bad/old token
357+ return c .getToken (ctx , resource , ar .AccessToken )
358+ } else {
334359 if ! stResp .AccessToken .RefreshOn .T .IsZero () && ! stResp .AccessToken .RefreshOn .T .After (now ()) && c .canRefresh .CompareAndSwap (false , true ) {
335360 defer c .canRefresh .Store (false )
336- if tr , er := c .getToken (ctx , resource ); er == nil {
361+ if tr , er := c .getToken (ctx , resource , o . claims ); er == nil {
337362 return tr , nil
338363 }
339364 }
340365 ar .AccessToken , err = c .authParams .AuthnScheme .FormatAccessToken (ar .AccessToken )
341366 return ar , err
342367 }
343368 }
344- return c .getToken (ctx , resource )
369+ return c .getToken (ctx , resource , "" )
345370}
346371
347- func (c Client ) getToken (ctx context.Context , resource string ) (AuthResult , error ) {
372+ func (c Client ) getToken (ctx context.Context , resource string , badToken string ) (AuthResult , error ) {
348373 switch c .source {
349374 case AzureArc :
350375 return c .acquireTokenForAzureArc (ctx , resource )
@@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro
355380 case DefaultToIMDS :
356381 return c .acquireTokenForIMDS (ctx , resource )
357382 case AppService :
358- return c .acquireTokenForAppService (ctx , resource )
383+ return c .acquireTokenForAppService (ctx , resource , badToken )
359384 case ServiceFabric :
360385 return c .acquireTokenForServiceFabric (ctx , resource )
361386 default :
362387 return AuthResult {}, fmt .Errorf ("unsupported source %q" , c .source )
363388 }
364389}
365390
366- func (c Client ) acquireTokenForAppService (ctx context.Context , resource string ) (AuthResult , error ) {
367- req , err := createAppServiceAuthRequest (ctx , c .miType , resource )
391+ func (c Client ) acquireTokenForAppService (ctx context.Context , resource string , badToken string ) (AuthResult , error ) {
392+ req , err := createAppServiceAuthRequest (ctx , c .miType , resource , badToken , c . clientCapabilities )
368393 if err != nil {
369394 return AuthResult {}, err
370395 }
@@ -569,16 +594,27 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
569594 return r , err
570595}
571596
572- func createAppServiceAuthRequest (ctx context.Context , id ID , resource string ) (* http.Request , error ) {
597+ func createAppServiceAuthRequest (ctx context.Context , id ID , resource string , badToken string , cc [] string ) (* http.Request , error ) {
573598 identityEndpoint := os .Getenv (identityEndpointEnvVar )
574599 req , err := http .NewRequestWithContext (ctx , http .MethodGet , identityEndpoint , nil )
575600 if err != nil {
576601 return nil , err
577602 }
578603 req .Header .Set ("X-IDENTITY-HEADER" , os .Getenv (identityHeaderEnvVar ))
604+
579605 q := req .URL .Query ()
580606 q .Set ("api-version" , appServiceAPIVersion )
581607 q .Set ("resource" , resource )
608+
609+ if badToken != "" {
610+ hash := sha256 .Sum256 ([]byte (badToken ))
611+ q .Set ("token_sha256_to_refresh" , hex .EncodeToString (hash [:]))
612+ }
613+
614+ if len (cc ) > 0 {
615+ q .Set ("xms_cc" , strings .Join (cc , "," ))
616+ }
617+
582618 switch t := id .(type ) {
583619 case UserAssignedClientID :
584620 q .Set (miQueryParameterClientId , string (t ))
0 commit comments