@@ -11,6 +11,8 @@ package managedidentity
11
11
12
12
import (
13
13
"context"
14
+ "crypto/sha256"
15
+ "encoding/hex"
14
16
"encoding/json"
15
17
"fmt"
16
18
"io"
@@ -82,7 +84,7 @@ const (
82
84
tokenName = "Tokens"
83
85
84
86
// App Service
85
- appServiceAPIVersion = "2019-08-01 "
87
+ appServiceAPIVersion = "2025-03-30 "
86
88
87
89
// AzureML
88
90
azureMLAPIVersion = "2017-09-01"
@@ -178,6 +180,7 @@ type Client struct {
178
180
authParams authority.AuthParams
179
181
retryPolicyEnabled bool
180
182
canRefresh * atomic.Value
183
+ clientCapabilities []string
181
184
}
182
185
183
186
type AcquireTokenOptions struct {
@@ -192,14 +195,34 @@ type AcquireTokenOption func(o *AcquireTokenOptions)
192
195
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
193
196
func WithClaims (claims string ) AcquireTokenOption {
194
197
return func (o * AcquireTokenOptions ) {
195
- o .claims = claims
198
+ if claims != "" {
199
+ o .claims = claims
200
+ }
201
+ }
202
+ }
203
+
204
+ // WithClientCapabilities sets the client capabilities to be used in the request.
205
+ // This is used to enable specific features or behaviors in the token request.
206
+ // The capabilities are passed as a slice of strings, and empty strings are filtered out.
207
+ func WithClientCapabilities (capabilities []string ) ClientOption {
208
+ return func (o * Client ) {
209
+ var filteredCapabilities []string
210
+ for _ , cap := range capabilities {
211
+ if cap != "" {
212
+ filteredCapabilities = append (filteredCapabilities , cap )
213
+ }
214
+ }
215
+ o .clientCapabilities = filteredCapabilities
196
216
}
197
217
}
198
218
199
219
// WithHTTPClient allows for a custom HTTP client to be set.
220
+ // if nil, the default HTTP client will be used.
200
221
func WithHTTPClient (httpClient ops.HTTPClient ) ClientOption {
201
222
return func (c * Client ) {
202
- c .httpClient = httpClient
223
+ if httpClient != nil {
224
+ c .httpClient = httpClient
225
+ }
203
226
}
204
227
}
205
228
@@ -323,28 +346,30 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
323
346
}
324
347
c .authParams .Scopes = []string {resource }
325
348
326
- // ignore cached access tokens when given claims
327
- if o .claims == "" {
328
- stResp , err := cacheManager .Read (ctx , c .authParams )
329
- if err != nil {
330
- return AuthResult {}, err
331
- }
332
- ar , err := base .AuthResultFromStorage (stResp )
333
- if err == nil {
349
+ stResp , err := cacheManager .Read (ctx , c .authParams )
350
+ if err != nil {
351
+ return AuthResult {}, err
352
+ }
353
+ ar , err := base .AuthResultFromStorage (stResp )
354
+ if err == nil {
355
+ if o .claims != "" {
356
+ // When the claims are set, we need to passon bad/old token
357
+ return c .getToken (ctx , resource , ar .AccessToken )
358
+ } else {
334
359
if ! stResp .AccessToken .RefreshOn .T .IsZero () && ! stResp .AccessToken .RefreshOn .T .After (now ()) && c .canRefresh .CompareAndSwap (false , true ) {
335
360
defer c .canRefresh .Store (false )
336
- if tr , er := c .getToken (ctx , resource ); er == nil {
361
+ if tr , er := c .getToken (ctx , resource , o . claims ); er == nil {
337
362
return tr , nil
338
363
}
339
364
}
340
365
ar .AccessToken , err = c .authParams .AuthnScheme .FormatAccessToken (ar .AccessToken )
341
366
return ar , err
342
367
}
343
368
}
344
- return c .getToken (ctx , resource )
369
+ return c .getToken (ctx , resource , "" )
345
370
}
346
371
347
- func (c Client ) getToken (ctx context.Context , resource string ) (AuthResult , error ) {
372
+ func (c Client ) getToken (ctx context.Context , resource string , badToken string ) (AuthResult , error ) {
348
373
switch c .source {
349
374
case AzureArc :
350
375
return c .acquireTokenForAzureArc (ctx , resource )
@@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro
355
380
case DefaultToIMDS :
356
381
return c .acquireTokenForIMDS (ctx , resource )
357
382
case AppService :
358
- return c .acquireTokenForAppService (ctx , resource )
383
+ return c .acquireTokenForAppService (ctx , resource , badToken )
359
384
case ServiceFabric :
360
385
return c .acquireTokenForServiceFabric (ctx , resource )
361
386
default :
362
387
return AuthResult {}, fmt .Errorf ("unsupported source %q" , c .source )
363
388
}
364
389
}
365
390
366
- func (c Client ) acquireTokenForAppService (ctx context.Context , resource string ) (AuthResult , error ) {
367
- req , err := createAppServiceAuthRequest (ctx , c .miType , resource )
391
+ func (c Client ) acquireTokenForAppService (ctx context.Context , resource string , badToken string ) (AuthResult , error ) {
392
+ req , err := createAppServiceAuthRequest (ctx , c .miType , resource , badToken , c . clientCapabilities )
368
393
if err != nil {
369
394
return AuthResult {}, err
370
395
}
@@ -569,16 +594,27 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
569
594
return r , err
570
595
}
571
596
572
- func createAppServiceAuthRequest (ctx context.Context , id ID , resource string ) (* http.Request , error ) {
597
+ func createAppServiceAuthRequest (ctx context.Context , id ID , resource string , badToken string , cc [] string ) (* http.Request , error ) {
573
598
identityEndpoint := os .Getenv (identityEndpointEnvVar )
574
599
req , err := http .NewRequestWithContext (ctx , http .MethodGet , identityEndpoint , nil )
575
600
if err != nil {
576
601
return nil , err
577
602
}
578
603
req .Header .Set ("X-IDENTITY-HEADER" , os .Getenv (identityHeaderEnvVar ))
604
+
579
605
q := req .URL .Query ()
580
606
q .Set ("api-version" , appServiceAPIVersion )
581
607
q .Set ("resource" , resource )
608
+
609
+ if badToken != "" {
610
+ hash := sha256 .Sum256 ([]byte (badToken ))
611
+ q .Set ("token_sha256_to_refresh" , hex .EncodeToString (hash [:]))
612
+ }
613
+
614
+ if len (cc ) > 0 {
615
+ q .Set ("xms_cc" , strings .Join (cc , "," ))
616
+ }
617
+
582
618
switch t := id .(type ) {
583
619
case UserAssignedClientID :
584
620
q .Set (miQueryParameterClientId , string (t ))
0 commit comments