Skip to content

Commit c3c36f5

Browse files
authored
Adds new AAD scope value instead of account scope (Azure#3246)
As part of this PR we are adding a new AAD audience scope value "https://cosmos.azure.com/.default". We will be using this instead of account scope value in rust sdk. This new scope value will be eventually used across different clouds and regions. It is already accepted in public cloud across stage, canary and prod.
1 parent 5a4d5ea commit c3c36f5

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use azure_core::{
1414
headers::{HeaderValue, AUTHORIZATION, MS_DATE, VERSION},
1515
policies::{Policy, PolicyResult},
1616
request::Request,
17-
Context, Url,
17+
Context,
1818
},
1919
time::{self, OffsetDateTime},
2020
};
@@ -26,6 +26,7 @@ use crate::{pipeline::signature_target::SignatureTarget, resource_context::Resou
2626
use crate::utils::url_encode;
2727

2828
const AZURE_VERSION: &str = "2020-07-15";
29+
const COSMOS_AAD_SCOPE: &str = "https://cosmos.azure.com/.default";
2930

3031
#[derive(Debug, Clone)]
3132
enum Credential {
@@ -82,7 +83,6 @@ impl Policy for AuthorizationPolicy {
8283

8384
let auth = generate_authorization(
8485
&self.credential,
85-
request.url(),
8686
SignatureTarget::new(request.method(), resource_link, &date_string),
8787
)
8888
.await?;
@@ -110,15 +110,13 @@ impl Policy for AuthorizationPolicy {
110110
/// NOTE: Resource tokens are not yet supported.
111111
async fn generate_authorization(
112112
auth_token: &Credential,
113-
url: &Url,
114-
115113
// Unused unless feature="key_auth", but I don't want to mess with excluding it since it makes call sites more complicated
116114
#[allow(unused_variables)] signature_target: SignatureTarget<'_>,
117115
) -> azure_core::Result<String> {
118116
let token = match auth_token {
119117
Credential::Token(token_credential) => {
120118
let token = token_credential
121-
.get_token(&[&scope_from_url(url)], None)
119+
.get_token(&[COSMOS_AAD_SCOPE], None)
122120
.await?
123121
.token
124122
.secret()
@@ -133,14 +131,6 @@ async fn generate_authorization(
133131
Ok(url_encode(token))
134132
}
135133

136-
/// This function generates the scope string from the passed url. The scope string is used to
137-
/// request the AAD token.
138-
fn scope_from_url(url: &Url) -> String {
139-
let scheme = url.scheme();
140-
let hostname = url.host_str().unwrap();
141-
format!("{scheme}://{hostname}/.default")
142-
}
143-
144134
#[cfg(test)]
145135
mod tests {
146136
use std::sync::Arc;
@@ -150,11 +140,10 @@ mod tests {
150140
http::Method,
151141
time::{Duration, OffsetDateTime},
152142
};
153-
use url::Url;
154143

155144
use crate::{
156145
pipeline::{
157-
authorization_policy::{generate_authorization, scope_from_url, Credential},
146+
authorization_policy::{generate_authorization, Credential, COSMOS_AAD_SCOPE},
158147
signature_target::SignatureTarget,
159148
},
160149
resource_context::{ResourceLink, ResourceType},
@@ -188,12 +177,8 @@ mod tests {
188177
let cred = Arc::new(TestTokenCredential("test_token".to_string()));
189178
let auth_token = Credential::Token(cred);
190179

191-
// Use a fake URL since the actual endpoint URL is not important for this test
192-
let url = Url::parse("https://test_account.example.com/dbs/ToDoList").unwrap();
193-
194180
let ret = generate_authorization(
195181
&auth_token,
196-
&url,
197182
SignatureTarget::new(
198183
Method::Get,
199184
&ResourceLink::root(ResourceType::Databases).item("ToDoList"),
@@ -203,10 +188,8 @@ mod tests {
203188
.await
204189
.unwrap();
205190

206-
let expected: String = url_encode(
207-
b"type=aad&ver=1.0&sig=test_token+https://test_account.example.com/.default",
208-
);
209-
191+
let expected: String =
192+
url_encode(format!("type=aad&ver=1.0&sig=test_token+{}", COSMOS_AAD_SCOPE).as_bytes());
210193
assert_eq!(ret, expected);
211194
}
212195

@@ -221,12 +204,8 @@ mod tests {
221204
"8F8xXXOptJxkblM1DBXW7a6NMI5oE8NnwPGYBmwxLCKfejOK7B7yhcCHMGvN3PBrlMLIOeol1Hv9RCdzAZR5sg==".into(),
222205
);
223206

224-
// Use a fake URL since the actual endpoint URL is not important for this test
225-
let url = Url::parse("https://test_account.example.com/dbs/ToDoList").unwrap();
226-
227207
let ret = generate_authorization(
228208
&auth_token,
229-
&url,
230209
SignatureTarget::new(
231210
Method::Get,
232211
&ResourceLink::root(ResourceType::Databases)
@@ -256,12 +235,8 @@ mod tests {
256235
"dsZQi3KtZmCv1ljt3VNWNm7sQUF1y5rJfC6kv5JiwvW0EndXdDku/dkKBp8/ufDToSxL".into(),
257236
);
258237

259-
// Use a fake URL since the actual endpoint URL is not important for this test
260-
let url = Url::parse("https://test_account.example.com/dbs/ToDoList").unwrap();
261-
262238
let ret = generate_authorization(
263239
&auth_token,
264-
&url,
265240
SignatureTarget::new(
266241
Method::Get,
267242
&ResourceLink::root(ResourceType::Databases).item("ToDoList"),
@@ -277,9 +252,69 @@ mod tests {
277252
assert_eq!(ret, expected);
278253
}
279254

280-
#[test]
281-
fn scope_from_url_extracts_correct_scope() {
282-
let scope = scope_from_url(&Url::parse("https://example.com/dbs/test_db/colls").unwrap());
283-
assert_eq!(scope, "https://example.com/.default");
255+
/// Tests that AAD authentication explicitly uses the constant scope value.
256+
#[tokio::test]
257+
async fn aad_token_uses_constant_scope() {
258+
use std::sync::Mutex;
259+
260+
// Mock credential that captures the exact scopes passed to get_token
261+
#[derive(Debug)]
262+
struct ScopeCapturingCredential {
263+
captured_scopes: Arc<Mutex<Vec<Vec<String>>>>,
264+
}
265+
266+
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
267+
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
268+
impl TokenCredential for ScopeCapturingCredential {
269+
async fn get_token(
270+
&self,
271+
scopes: &[&str],
272+
_: Option<TokenRequestOptions<'_>>,
273+
) -> azure_core::Result<AccessToken> {
274+
self.captured_scopes
275+
.lock()
276+
.unwrap()
277+
.push(scopes.iter().map(|s| s.to_string()).collect());
278+
279+
Ok(AccessToken::new(
280+
"mock_token".to_string(),
281+
OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)),
282+
))
283+
}
284+
}
285+
286+
let captured_scopes = Arc::new(Mutex::new(Vec::new()));
287+
let cred = Arc::new(ScopeCapturingCredential {
288+
captured_scopes: captured_scopes.clone(),
289+
});
290+
let auth_token = Credential::Token(cred);
291+
292+
let time_nonce =
293+
azure_core::time::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap();
294+
let date_string = azure_core::time::to_rfc7231(&time_nonce).to_lowercase();
295+
296+
let _result = generate_authorization(
297+
&auth_token,
298+
SignatureTarget::new(
299+
Method::Get,
300+
&ResourceLink::root(ResourceType::Databases).item("TestDB"),
301+
&date_string,
302+
),
303+
)
304+
.await
305+
.unwrap();
306+
307+
// Verifies that get_token was called exactly once with the constant scope
308+
let scopes = captured_scopes.lock().unwrap();
309+
assert_eq!(scopes.len(), 1, "get_token should be called exactly once");
310+
assert_eq!(
311+
scopes[0].len(),
312+
1,
313+
"get_token should be called with exactly one scope"
314+
);
315+
assert_eq!(
316+
scopes[0][0], COSMOS_AAD_SCOPE,
317+
"get_token should be called with COSMOS_AAD_SCOPE constant"
318+
);
284319
}
285320
}

0 commit comments

Comments
 (0)