@@ -14,7 +14,7 @@ use azure_core::{
1414 headers:: { HeaderValue , AUTHORIZATION , MS_DATE , VERSION } ,
1515 policies:: { Policy , PolicyResult } ,
1616 request:: Request ,
17- Context , Url ,
17+ Context ,
1818 } ,
1919 time:: { self , OffsetDateTime } ,
2020} ;
@@ -26,6 +26,7 @@ use crate::{pipeline::signature_target::SignatureTarget, resource_context::Resou
2626use crate :: utils:: url_encode;
2727
2828const AZURE_VERSION : & str = "2020-07-15" ;
29+ const COSMOS_AAD_SCOPE : & str = "https://cosmos.azure.com/.default" ;
2930
3031#[ derive( Debug , Clone ) ]
3132enum Credential {
@@ -82,7 +83,6 @@ impl Policy for AuthorizationPolicy {
8283
8384 let auth = generate_authorization (
8485 & self . credential ,
85- request. url ( ) ,
8686 SignatureTarget :: new ( request. method ( ) , resource_link, & date_string) ,
8787 )
8888 . await ?;
@@ -110,15 +110,13 @@ impl Policy for AuthorizationPolicy {
110110/// NOTE: Resource tokens are not yet supported.
111111async fn generate_authorization (
112112 auth_token : & Credential ,
113- url : & Url ,
114-
115113 // Unused unless feature="key_auth", but I don't want to mess with excluding it since it makes call sites more complicated
116114 #[ allow( unused_variables) ] signature_target : SignatureTarget < ' _ > ,
117115) -> azure_core:: Result < String > {
118116 let token = match auth_token {
119117 Credential :: Token ( token_credential) => {
120118 let token = token_credential
121- . get_token ( & [ & scope_from_url ( url ) ] , None )
119+ . get_token ( & [ COSMOS_AAD_SCOPE ] , None )
122120 . await ?
123121 . token
124122 . secret ( )
@@ -133,14 +131,6 @@ async fn generate_authorization(
133131 Ok ( url_encode ( token) )
134132}
135133
136- /// This function generates the scope string from the passed url. The scope string is used to
137- /// request the AAD token.
138- fn scope_from_url ( url : & Url ) -> String {
139- let scheme = url. scheme ( ) ;
140- let hostname = url. host_str ( ) . unwrap ( ) ;
141- format ! ( "{scheme}://{hostname}/.default" )
142- }
143-
144134#[ cfg( test) ]
145135mod tests {
146136 use std:: sync:: Arc ;
@@ -150,11 +140,10 @@ mod tests {
150140 http:: Method ,
151141 time:: { Duration , OffsetDateTime } ,
152142 } ;
153- use url:: Url ;
154143
155144 use crate :: {
156145 pipeline:: {
157- authorization_policy:: { generate_authorization, scope_from_url , Credential } ,
146+ authorization_policy:: { generate_authorization, Credential , COSMOS_AAD_SCOPE } ,
158147 signature_target:: SignatureTarget ,
159148 } ,
160149 resource_context:: { ResourceLink , ResourceType } ,
@@ -188,12 +177,8 @@ mod tests {
188177 let cred = Arc :: new ( TestTokenCredential ( "test_token" . to_string ( ) ) ) ;
189178 let auth_token = Credential :: Token ( cred) ;
190179
191- // Use a fake URL since the actual endpoint URL is not important for this test
192- let url = Url :: parse ( "https://test_account.example.com/dbs/ToDoList" ) . unwrap ( ) ;
193-
194180 let ret = generate_authorization (
195181 & auth_token,
196- & url,
197182 SignatureTarget :: new (
198183 Method :: Get ,
199184 & ResourceLink :: root ( ResourceType :: Databases ) . item ( "ToDoList" ) ,
@@ -203,10 +188,8 @@ mod tests {
203188 . await
204189 . unwrap ( ) ;
205190
206- let expected: String = url_encode (
207- b"type=aad&ver=1.0&sig=test_token+https://test_account.example.com/.default" ,
208- ) ;
209-
191+ let expected: String =
192+ url_encode ( format ! ( "type=aad&ver=1.0&sig=test_token+{}" , COSMOS_AAD_SCOPE ) . as_bytes ( ) ) ;
210193 assert_eq ! ( ret, expected) ;
211194 }
212195
@@ -221,12 +204,8 @@ mod tests {
221204 "8F8xXXOptJxkblM1DBXW7a6NMI5oE8NnwPGYBmwxLCKfejOK7B7yhcCHMGvN3PBrlMLIOeol1Hv9RCdzAZR5sg==" . into ( ) ,
222205 ) ;
223206
224- // Use a fake URL since the actual endpoint URL is not important for this test
225- let url = Url :: parse ( "https://test_account.example.com/dbs/ToDoList" ) . unwrap ( ) ;
226-
227207 let ret = generate_authorization (
228208 & auth_token,
229- & url,
230209 SignatureTarget :: new (
231210 Method :: Get ,
232211 & ResourceLink :: root ( ResourceType :: Databases )
@@ -256,12 +235,8 @@ mod tests {
256235 "dsZQi3KtZmCv1ljt3VNWNm7sQUF1y5rJfC6kv5JiwvW0EndXdDku/dkKBp8/ufDToSxL" . into ( ) ,
257236 ) ;
258237
259- // Use a fake URL since the actual endpoint URL is not important for this test
260- let url = Url :: parse ( "https://test_account.example.com/dbs/ToDoList" ) . unwrap ( ) ;
261-
262238 let ret = generate_authorization (
263239 & auth_token,
264- & url,
265240 SignatureTarget :: new (
266241 Method :: Get ,
267242 & ResourceLink :: root ( ResourceType :: Databases ) . item ( "ToDoList" ) ,
@@ -277,9 +252,69 @@ mod tests {
277252 assert_eq ! ( ret, expected) ;
278253 }
279254
280- #[ test]
281- fn scope_from_url_extracts_correct_scope ( ) {
282- let scope = scope_from_url ( & Url :: parse ( "https://example.com/dbs/test_db/colls" ) . unwrap ( ) ) ;
283- assert_eq ! ( scope, "https://example.com/.default" ) ;
255+ /// Tests that AAD authentication explicitly uses the constant scope value.
256+ #[ tokio:: test]
257+ async fn aad_token_uses_constant_scope ( ) {
258+ use std:: sync:: Mutex ;
259+
260+ // Mock credential that captures the exact scopes passed to get_token
261+ #[ derive( Debug ) ]
262+ struct ScopeCapturingCredential {
263+ captured_scopes : Arc < Mutex < Vec < Vec < String > > > > ,
264+ }
265+
266+ #[ cfg_attr( target_arch = "wasm32" , async_trait:: async_trait( ?Send ) ) ]
267+ #[ cfg_attr( not( target_arch = "wasm32" ) , async_trait:: async_trait) ]
268+ impl TokenCredential for ScopeCapturingCredential {
269+ async fn get_token (
270+ & self ,
271+ scopes : & [ & str ] ,
272+ _: Option < TokenRequestOptions < ' _ > > ,
273+ ) -> azure_core:: Result < AccessToken > {
274+ self . captured_scopes
275+ . lock ( )
276+ . unwrap ( )
277+ . push ( scopes. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ) ;
278+
279+ Ok ( AccessToken :: new (
280+ "mock_token" . to_string ( ) ,
281+ OffsetDateTime :: now_utc ( ) . saturating_add ( Duration :: minutes ( 5 ) ) ,
282+ ) )
283+ }
284+ }
285+
286+ let captured_scopes = Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ;
287+ let cred = Arc :: new ( ScopeCapturingCredential {
288+ captured_scopes : captured_scopes. clone ( ) ,
289+ } ) ;
290+ let auth_token = Credential :: Token ( cred) ;
291+
292+ let time_nonce =
293+ azure_core:: time:: parse_rfc3339 ( "1900-01-01T01:00:00.000000000+00:00" ) . unwrap ( ) ;
294+ let date_string = azure_core:: time:: to_rfc7231 ( & time_nonce) . to_lowercase ( ) ;
295+
296+ let _result = generate_authorization (
297+ & auth_token,
298+ SignatureTarget :: new (
299+ Method :: Get ,
300+ & ResourceLink :: root ( ResourceType :: Databases ) . item ( "TestDB" ) ,
301+ & date_string,
302+ ) ,
303+ )
304+ . await
305+ . unwrap ( ) ;
306+
307+ // Verifies that get_token was called exactly once with the constant scope
308+ let scopes = captured_scopes. lock ( ) . unwrap ( ) ;
309+ assert_eq ! ( scopes. len( ) , 1 , "get_token should be called exactly once" ) ;
310+ assert_eq ! (
311+ scopes[ 0 ] . len( ) ,
312+ 1 ,
313+ "get_token should be called with exactly one scope"
314+ ) ;
315+ assert_eq ! (
316+ scopes[ 0 ] [ 0 ] , COSMOS_AAD_SCOPE ,
317+ "get_token should be called with COSMOS_AAD_SCOPE constant"
318+ ) ;
284319 }
285320}
0 commit comments