@@ -5,6 +5,7 @@ package runtime
5
5
6
6
import (
7
7
"context"
8
+ "encoding/base64"
8
9
"fmt"
9
10
"net/http"
10
11
"strings"
@@ -14,9 +15,13 @@ import (
14
15
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
15
16
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
16
17
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
18
+ azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
17
19
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
18
20
)
19
21
22
+ const headerAuxiliaryAuthorization = "x-ms-authorization-auxiliary"
23
+
24
+ // acquiringResourceState holds data for an auxiliary token request
20
25
type acquiringResourceState struct {
21
26
ctx context.Context
22
27
p * BearerTokenPolicy
@@ -26,7 +31,10 @@ type acquiringResourceState struct {
26
31
// acquire acquires or updates the resource; only one
27
32
// thread/goroutine at a time ever calls this function
28
33
func acquire (state acquiringResourceState ) (newResource azcore.AccessToken , newExpiration time.Time , err error ) {
29
- tk , err := state .p .cred .GetToken (state .ctx , azpolicy.TokenRequestOptions {Scopes : state .p .options .Scopes })
34
+ tk , err := state .p .cred .GetToken (state .ctx , azpolicy.TokenRequestOptions {
35
+ Scopes : state .p .scopes ,
36
+ TenantID : state .tenant ,
37
+ })
30
38
if err != nil {
31
39
return azcore.AccessToken {}, time.Time {}, err
32
40
}
@@ -35,13 +43,10 @@ func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newE
35
43
36
44
// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
37
45
type BearerTokenPolicy struct {
38
- // mainResource is the resource to be retreived using the tenant specified in the credential
39
- mainResource * temporal.Resource [azcore.AccessToken , acquiringResourceState ]
40
- // auxResources are additional resources that are required for cross-tenant applications
41
46
auxResources map [string ]* temporal.Resource [azcore.AccessToken , acquiringResourceState ]
42
- // the following fields are read-only
43
- cred azcore.TokenCredential
44
- options armpolicy. BearerTokenOptions
47
+ btp * azruntime. BearerTokenPolicy
48
+ cred azcore.TokenCredential
49
+ scopes [] string
45
50
}
46
51
47
52
// NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens.
@@ -51,36 +56,90 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTok
51
56
if opts == nil {
52
57
opts = & armpolicy.BearerTokenOptions {}
53
58
}
54
- p := & BearerTokenPolicy {
55
- cred : cred ,
56
- options : * opts ,
57
- mainResource : temporal .NewResource (acquire ),
59
+ p := & BearerTokenPolicy {cred : cred }
60
+ p . auxResources = make ( map [ string ] * temporal. Resource [azcore. AccessToken , acquiringResourceState ], len ( opts . AuxiliaryTenants ))
61
+ for _ , t := range opts . AuxiliaryTenants {
62
+ p . auxResources [ t ] = temporal .NewResource (acquire )
58
63
}
64
+ p .scopes = make ([]string , len (opts .Scopes ))
65
+ copy (p .scopes , opts .Scopes )
66
+ p .btp = azruntime .NewBearerTokenPolicy (cred , opts .Scopes , & azpolicy.BearerTokenOptions {
67
+ AuthorizationHandler : azpolicy.AuthorizationHandler {
68
+ OnChallenge : p .onChallenge ,
69
+ OnRequest : p .onRequest ,
70
+ },
71
+ })
59
72
return p
60
73
}
61
74
62
- // Do authorizes a request with a bearer token
63
- func (b * BearerTokenPolicy ) Do (req * azpolicy.Request ) (* http.Response , error ) {
75
+ func (b * BearerTokenPolicy ) onChallenge (req * azpolicy.Request , res * http.Response , authNZ func (azpolicy.TokenRequestOptions ) error ) error {
76
+ challenge := res .Header .Get (shared .HeaderWWWAuthenticate )
77
+ claims , err := parseChallenge (challenge )
78
+ if err != nil {
79
+ // the challenge contains claims we can't parse
80
+ return err
81
+ } else if claims != "" {
82
+ // request a new token having the specified claims, send the request again
83
+ return authNZ (azpolicy.TokenRequestOptions {Claims : claims , Scopes : b .scopes })
84
+ }
85
+ // auth challenge didn't include claims, so this is a simple authorization failure
86
+ return azruntime .NewResponseError (res )
87
+ }
88
+
89
+ // onRequest authorizes requests with one or more bearer tokens
90
+ func (b * BearerTokenPolicy ) onRequest (req * azpolicy.Request , authNZ func (azpolicy.TokenRequestOptions ) error ) error {
91
+ // authorize the request with a token for the primary tenant
92
+ err := authNZ (azpolicy.TokenRequestOptions {Scopes : b .scopes })
93
+ if err != nil || len (b .auxResources ) == 0 {
94
+ return err
95
+ }
96
+ // add tokens for auxiliary tenants
64
97
as := acquiringResourceState {
65
98
ctx : req .Raw ().Context (),
66
99
p : b ,
67
100
}
68
- tk , err := b .mainResource .Get (as )
69
- if err != nil {
70
- return nil , err
71
- }
72
- req .Raw ().Header .Set (shared .HeaderAuthorization , shared .BearerTokenPrefix + tk .Token )
73
- auxTokens := []string {}
101
+ auxTokens := make ([]string , 0 , len (b .auxResources ))
74
102
for tenant , er := range b .auxResources {
75
103
as .tenant = tenant
76
104
auxTk , err := er .Get (as )
77
105
if err != nil {
78
- return nil , err
106
+ return err
79
107
}
80
108
auxTokens = append (auxTokens , fmt .Sprintf ("%s%s" , shared .BearerTokenPrefix , auxTk .Token ))
81
109
}
82
- if len (auxTokens ) > 0 {
83
- req .Raw ().Header .Set (shared .HeaderAuxiliaryAuthorization , strings .Join (auxTokens , ", " ))
110
+ req .Raw ().Header .Set (headerAuxiliaryAuthorization , strings .Join (auxTokens , ", " ))
111
+ return nil
112
+ }
113
+
114
+ // Do authorizes a request with a bearer token
115
+ func (b * BearerTokenPolicy ) Do (req * azpolicy.Request ) (* http.Response , error ) {
116
+ return b .btp .Do (req )
117
+ }
118
+
119
+ // parseChallenge parses claims from an authentication challenge issued by ARM so a client can request a token
120
+ // that will satisfy conditional access policies. It returns a non-nil error when the given value contains
121
+ // claims it can't parse. If the value contains no claims, it returns an empty string and a nil error.
122
+ func parseChallenge (wwwAuthenticate string ) (string , error ) {
123
+ claims := ""
124
+ var err error
125
+ for _ , param := range strings .Split (wwwAuthenticate , "," ) {
126
+ if _ , after , found := strings .Cut (param , "claims=" ); found {
127
+ if claims != "" {
128
+ // The header contains multiple challenges, at least two of which specify claims. The specs allow this
129
+ // but it's unclear what a client should do in this case and there's as yet no concrete example of it.
130
+ err = fmt .Errorf ("found multiple claims challenges in %q" , wwwAuthenticate )
131
+ break
132
+ }
133
+ // trim stuff that would get an error from RawURLEncoding; claims may or may not be padded
134
+ claims = strings .Trim (after , `\"=` )
135
+ // we don't return this error because it's something unhelpful like "illegal base64 data at input byte 42"
136
+ if b , decErr := base64 .RawURLEncoding .DecodeString (claims ); decErr == nil {
137
+ claims = string (b )
138
+ } else {
139
+ err = fmt .Errorf ("failed to parse claims from %q" , wwwAuthenticate )
140
+ break
141
+ }
142
+ }
84
143
}
85
- return req . Next ()
144
+ return claims , err
86
145
}
0 commit comments