77package internal
88
99import (
10+ "bytes"
1011 "context"
1112 "fmt"
1213 "net/http"
@@ -17,11 +18,29 @@ import (
1718 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
1819 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1920 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
21+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
2022 "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
2123 "github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
2224 "github.com/stretchr/testify/require"
2325)
2426
27+ const (
28+ challengedToken = "needs more claims"
29+ claimsToken = "all the claims"
30+ kvChallenge = `Bearer authorization="https://login.microsoftonline.com/tenant", resource="https://vault.azure.net"`
31+ caeChallenge1 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzE="`
32+ caeChallenge2 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzI="`
33+ )
34+
35+ // requireToken is a mock.Response predicate that checks a request for the expected token
36+ var requireToken = func (t * testing.T , want string ) func (req * http.Request ) bool {
37+ return func (r * http.Request ) bool {
38+ _ , actual , _ := strings .Cut (r .Header .Get ("Authorization" ), " " )
39+ require .Equal (t , want , actual )
40+ return true
41+ }
42+ }
43+
2544type credentialFunc func (context.Context , policy.TokenRequestOptions ) (azcore.AccessToken , error )
2645
2746func (cf credentialFunc ) GetToken (ctx context.Context , options policy.TokenRequestOptions ) (azcore.AccessToken , error ) {
@@ -100,6 +119,183 @@ func TestChallengePolicy(t *testing.T) {
100119 }
101120}
102121
122+ func TestChallengePolicy_CAE (t * testing.T ) {
123+ srv , close := mock .NewServer (mock .WithTransformAllRequestsToTestServerUrl ())
124+ defer close ()
125+ srv .AppendResponse (
126+ mock .WithHeader ("WWW-Authenticate" , kvChallenge ),
127+ mock .WithStatusCode (401 ),
128+ mock .WithPredicate (requireToken (t , "" )),
129+ )
130+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
131+ srv .AppendResponse ()
132+
133+ srv .AppendResponse (
134+ mock .WithHeader ("WWW-Authenticate" , caeChallenge1 ),
135+ mock .WithStatusCode (401 ),
136+ mock .WithPredicate (requireToken (t , challengedToken )),
137+ )
138+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
139+ srv .AppendResponse (
140+ mock .WithPredicate (requireToken (t , claimsToken )),
141+ )
142+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
143+
144+ tkReqs := 0
145+ cred := credentialFunc (func (ctx context.Context , tro policy.TokenRequestOptions ) (azcore.AccessToken , error ) {
146+ require .True (t , tro .EnableCAE )
147+ tkReqs += 1
148+ tk := challengedToken
149+ switch tkReqs {
150+ case 1 :
151+ require .Empty (t , tro .Claims )
152+ case 2 :
153+ tk = claimsToken
154+ require .Equal (t , "testing1" , tro .Claims )
155+ default :
156+ t .Fatal ("unexpected token request" )
157+ }
158+ return azcore.AccessToken {Token : tk , ExpiresOn : time .Now ().Add (time .Hour )}, nil
159+ })
160+ p := NewKeyVaultChallengePolicy (cred , nil )
161+ pl := runtime .NewPipeline ("" , "" ,
162+ runtime.PipelineOptions {PerRetry : []policy.Policy {p }},
163+ & policy.ClientOptions {Transport : srv },
164+ )
165+
166+ // req 1 kv then regular
167+ req , err := runtime .NewRequest (context .Background (), "POST" , "https://42.vault.azure.net" )
168+ require .NoError (t , err )
169+ err = req .SetBody (streaming .NopCloser (bytes .NewReader ([]byte ("test" ))), "text/plain" )
170+ require .NoError (t , err )
171+ res , err := pl .Do (req )
172+ require .NoError (t , err )
173+ require .Equal (t , 200 , res .StatusCode )
174+ require .Equal (t , 1 , tkReqs )
175+
176+ // req 2 cae
177+ req , err = runtime .NewRequest (context .Background (), "POST" , "https://42.vault.azure.net" )
178+ require .NoError (t , err )
179+ err = req .SetBody (streaming .NopCloser (bytes .NewReader ([]byte ("test2" ))), "text/plain" )
180+ require .NoError (t , err )
181+ res , err = pl .Do (req )
182+ require .NoError (t , err )
183+ require .Equal (t , 200 , res .StatusCode )
184+ require .Equal (t , 2 , tkReqs )
185+ }
186+
187+ func TestChallengePolicy_KVThenCAE (t * testing.T ) {
188+ srv , close := mock .NewServer (mock .WithTransformAllRequestsToTestServerUrl ())
189+ defer close ()
190+ srv .AppendResponse (
191+ mock .WithHeader ("WWW-Authenticate" , kvChallenge ),
192+ mock .WithStatusCode (401 ),
193+ mock .WithPredicate (requireToken (t , "" )),
194+ )
195+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
196+ srv .AppendResponse (
197+ mock .WithHeader ("WWW-Authenticate" , caeChallenge1 ),
198+ mock .WithStatusCode (401 ),
199+ mock .WithPredicate (requireToken (t , challengedToken )),
200+ )
201+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
202+ srv .AppendResponse (
203+ mock .WithPredicate (requireToken (t , claimsToken )),
204+ )
205+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
206+
207+ tkReqs := 0
208+ cred := credentialFunc (func (ctx context.Context , tro policy.TokenRequestOptions ) (azcore.AccessToken , error ) {
209+ require .True (t , tro .EnableCAE )
210+ tkReqs += 1
211+ tk := challengedToken
212+ switch tkReqs {
213+ case 1 :
214+ require .Empty (t , tro .Claims )
215+ case 2 :
216+ tk = claimsToken
217+ require .Equal (t , "testing1" , tro .Claims )
218+ default :
219+ t .Fatal ("unexpected token request" )
220+ }
221+ return azcore.AccessToken {Token : tk , ExpiresOn : time .Now ().Add (time .Hour )}, nil
222+ })
223+ p := NewKeyVaultChallengePolicy (cred , nil )
224+ pl := runtime .NewPipeline ("" , "" ,
225+ runtime.PipelineOptions {PerRetry : []policy.Policy {p }},
226+ & policy.ClientOptions {Transport : srv },
227+ )
228+ req , err := runtime .NewRequest (context .Background (), "GET" , "https://42.vault.azure.net" )
229+ require .NoError (t , err )
230+ res , err := pl .Do (req )
231+ require .NoError (t , err )
232+ require .Equal (t , 200 , res .StatusCode )
233+ require .Equal (t , tkReqs , 2 )
234+ }
235+
236+ func TestChallengePolicy_TwoCAEChallenges (t * testing.T ) {
237+ srv , close := mock .NewServer (mock .WithTransformAllRequestsToTestServerUrl ())
238+ defer close ()
239+ srv .AppendResponse (
240+ mock .WithHeader ("WWW-Authenticate" , kvChallenge ),
241+ mock .WithStatusCode (401 ),
242+ mock .WithPredicate (requireToken (t , "" )),
243+ )
244+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
245+ srv .AppendResponse ()
246+
247+ srv .AppendResponse (
248+ mock .WithHeader ("WWW-Authenticate" , caeChallenge1 ),
249+ mock .WithStatusCode (401 ),
250+ mock .WithPredicate (requireToken (t , challengedToken )),
251+ )
252+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
253+ srv .AppendResponse (
254+ mock .WithHeader ("WWW-Authenticate" , caeChallenge2 ),
255+ mock .WithStatusCode (401 ),
256+ mock .WithPredicate (requireToken (t , claimsToken )),
257+ )
258+ srv .AppendResponse () // when a response's predicate returns true, srv pops the following one
259+ tkReqs := 0
260+ cred := credentialFunc (func (ctx context.Context , tro policy.TokenRequestOptions ) (azcore.AccessToken , error ) {
261+ require .True (t , tro .EnableCAE )
262+ tk := challengedToken
263+ tkReqs += 1
264+ switch tkReqs {
265+ case 1 :
266+ require .Empty (t , tro .Claims )
267+ case 2 :
268+ tk = claimsToken
269+ require .Equal (t , "testing1" , tro .Claims )
270+ default :
271+ t .Fatal ("unexpected token request" )
272+ }
273+ return azcore.AccessToken {Token : tk , ExpiresOn : time .Now ().Add (time .Hour )}, nil
274+ })
275+ p := NewKeyVaultChallengePolicy (cred , nil )
276+ pl := runtime .NewPipeline ("" , "" ,
277+ runtime.PipelineOptions {PerRetry : []policy.Policy {p }},
278+ & policy.ClientOptions {Transport : srv },
279+ )
280+
281+ // req 1 kv then regular
282+ req , err := runtime .NewRequest (context .Background (), "GET" , "https://42.vault.azure.net" )
283+ require .NoError (t , err )
284+ res , err := pl .Do (req )
285+ require .NoError (t , err )
286+ require .Equal (t , 200 , res .StatusCode )
287+ require .Equal (t , tkReqs , 1 )
288+
289+ // req 2 cae twice
290+ req , err = runtime .NewRequest (context .Background (), "GET" , "https://42.vault.azure.net" )
291+ require .NoError (t , err )
292+ res , err = pl .Do (req )
293+ require .NoError (t , err )
294+ require .Equal (t , 401 , res .StatusCode )
295+ require .Equal (t , caeChallenge2 , res .Header .Get ("WWW-Authenticate" ))
296+ require .Equal (t , tkReqs , 2 )
297+ }
298+
103299func TestParseTenant (t * testing.T ) {
104300 actual := parseTenant ("" )
105301 require .Empty (t , actual )
0 commit comments