Skip to content

Commit 776adf5

Browse files
authored
Add ChainedTokenCredential (Azure#2309)
* add ChainedTokenCredential * move to root (per @heaths) * updates from PR feedback * removed builder per guidelines * renamed the file per request * added ChainedTokenCredentialOptions * align with expectations from Azure#2306. Note, this does not use ClientAssertionCredentialOptions yet, as the aforementioned PR must be merged first. * added retry_sources to model the go impl * add clearing `successful_credential` on clearing token cache * add `impl From<&[Arc<dyn TokenCredential>]> for ChainedTokenCredential` * add integration tests to validate retry works as expected --------- Co-authored-by: Brian Caswell <[email protected]>
1 parent 3130fe4 commit 776adf5

File tree

3 files changed

+259
-1
lines changed

3 files changed

+259
-1
lines changed
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
use crate::credentials::cache::TokenCache;
5+
use crate::TokenCredentialOptions;
6+
use async_lock::RwLock;
7+
use azure_core::{
8+
credentials::{AccessToken, TokenCredential},
9+
error::{Error, ErrorKind},
10+
};
11+
use std::sync::Arc;
12+
13+
#[derive(Debug, Default)]
14+
/// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
15+
pub struct ChainedTokenCredentialOptions {
16+
pub retry_sources: bool,
17+
pub credential_options: TokenCredentialOptions,
18+
}
19+
20+
// TODO: Should probably remove this once we consolidate and unify credentials.
21+
impl From<TokenCredentialOptions> for ChainedTokenCredentialOptions {
22+
fn from(credential_options: TokenCredentialOptions) -> Self {
23+
Self {
24+
retry_sources: Default::default(),
25+
credential_options,
26+
}
27+
}
28+
}
29+
30+
/// Provides a user-configurable `TokenCredential` authentication flow for applications that will be deployed to Azure.
31+
///
32+
/// The credential types are tried in the order specified by the user.
33+
#[derive(Debug)]
34+
pub struct ChainedTokenCredential {
35+
#[allow(dead_code)]
36+
options: ChainedTokenCredentialOptions,
37+
sources: Vec<Arc<dyn TokenCredential>>,
38+
cache: TokenCache,
39+
successful_credential: RwLock<Option<Arc<dyn TokenCredential>>>,
40+
}
41+
42+
impl ChainedTokenCredential {
43+
/// Create a `ChainedTokenCredential` with options.
44+
pub fn new(options: Option<ChainedTokenCredentialOptions>) -> Self {
45+
Self {
46+
options: options.unwrap_or_default(),
47+
sources: Vec::new(),
48+
cache: TokenCache::new(),
49+
successful_credential: RwLock::new(None),
50+
}
51+
}
52+
53+
/// Add a credential source to the chain.
54+
pub fn add_source(&mut self, source: Arc<dyn TokenCredential>) {
55+
self.sources.push(source);
56+
}
57+
58+
async fn get_token_impl(
59+
&self,
60+
scopes: &[&str],
61+
) -> azure_core::Result<(Arc<dyn TokenCredential>, AccessToken)> {
62+
let mut errors = Vec::new();
63+
for source in &self.sources {
64+
let token_res = source.get_token(scopes).await;
65+
66+
match token_res {
67+
Ok(token) => return Ok((source.clone(), token)),
68+
Err(error) => errors.push(error),
69+
}
70+
}
71+
Err(Error::with_message(ErrorKind::Credential, || {
72+
format!(
73+
"Multiple errors were encountered while attempting to authenticate:\n{}",
74+
format_aggregate_error(&errors)
75+
)
76+
}))
77+
}
78+
79+
/// Try to fetch a token using each of the credential sources until one succeeds
80+
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
81+
if !self.options.retry_sources {
82+
if let Some(entry) = self.successful_credential.read().await.as_ref() {
83+
return entry.get_token(scopes).await;
84+
}
85+
let mut lock = self.successful_credential.write().await;
86+
// if after getting the write lock, we find that another thread has already found a credential, use that.
87+
if let Some(entry) = lock.as_ref() {
88+
return entry.get_token(scopes).await;
89+
}
90+
let (entry, token) = self.get_token_impl(scopes).await?;
91+
*lock = Some(entry);
92+
Ok(token)
93+
} else {
94+
// if we are retrying sources, we don't need to cache the successful credential
95+
Ok(self.get_token_impl(scopes).await?.1)
96+
}
97+
}
98+
}
99+
100+
impl From<&[Arc<dyn TokenCredential>]> for ChainedTokenCredential {
101+
fn from(credential_options: &[Arc<dyn TokenCredential>]) -> Self {
102+
Self {
103+
options: ChainedTokenCredentialOptions::default(),
104+
sources: credential_options.to_vec(),
105+
cache: TokenCache::new(),
106+
successful_credential: RwLock::new(None),
107+
}
108+
}
109+
}
110+
111+
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
112+
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
113+
impl TokenCredential for ChainedTokenCredential {
114+
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
115+
self.cache.get_token(scopes, self.get_token(scopes)).await
116+
}
117+
118+
/// Clear the credential's cache.
119+
async fn clear_cache(&self) -> azure_core::Result<()> {
120+
// clear the internal cache as well as each of the underlying providers
121+
self.cache.clear().await?;
122+
123+
for source in &self.sources {
124+
source.clear_cache().await?;
125+
}
126+
127+
// clear the successful credential if we are clearing the token cache
128+
self.successful_credential.write().await.take();
129+
130+
Ok(())
131+
}
132+
}
133+
134+
fn format_aggregate_error(errors: &[Error]) -> String {
135+
use std::error::Error;
136+
errors
137+
.iter()
138+
.map(|e| {
139+
let mut current: Option<&dyn Error> = Some(e);
140+
let mut stack = vec![];
141+
while let Some(err) = current.take() {
142+
stack.push(err.to_string());
143+
current = err.source();
144+
}
145+
stack.join(" - ")
146+
})
147+
.collect::<Vec<String>>()
148+
.join("\n")
149+
}
150+
151+
#[cfg(test)]
152+
mod tests {
153+
use super::*;
154+
use async_lock::Mutex;
155+
use azure_core::credentials::{AccessToken, TokenCredential};
156+
use azure_core_test::credentials::MockCredential;
157+
158+
/// `TokenFailure` is a mock credential that always fails to get a token.
159+
#[derive(Debug)]
160+
struct TokenFailure {
161+
counter: Mutex<u32>,
162+
}
163+
164+
impl TokenFailure {
165+
fn new() -> Self {
166+
Self {
167+
counter: Mutex::new(0),
168+
}
169+
}
170+
171+
async fn get_counter(&self) -> u32 {
172+
let count = self.counter.lock().await;
173+
*count
174+
}
175+
}
176+
177+
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
178+
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
179+
impl TokenCredential for TokenFailure {
180+
async fn get_token(&self, _scopes: &[&str]) -> azure_core::Result<AccessToken> {
181+
let mut count = self.counter.lock().await;
182+
*count += 1;
183+
Err(Error::message(ErrorKind::Credential, "failed to get token"))
184+
}
185+
186+
async fn clear_cache(&self) -> azure_core::Result<()> {
187+
Ok(())
188+
}
189+
}
190+
191+
#[tokio::test]
192+
async fn test_basic() -> azure_core::Result<()> {
193+
let providers: Vec<Arc<dyn TokenCredential>> = vec![Arc::new(MockCredential {})];
194+
let credentials = ChainedTokenCredential::from(providers.as_slice());
195+
let scopes = ["https://management.azure.com/.default"];
196+
let token = credentials.get_token(&scopes).await?;
197+
assert_eq!(
198+
token.token.secret(),
199+
"TEST TOKEN https://management.azure.com/.default"
200+
);
201+
202+
Ok(())
203+
}
204+
205+
#[tokio::test]
206+
async fn test_with_retry() -> azure_core::Result<()> {
207+
let token_failure = Arc::new(TokenFailure::new());
208+
let mut chained_credential =
209+
ChainedTokenCredential::new(Some(ChainedTokenCredentialOptions {
210+
retry_sources: true,
211+
..Default::default()
212+
}));
213+
chained_credential.add_source(token_failure.clone());
214+
chained_credential.add_source(Arc::new(MockCredential {}));
215+
216+
let scopes = ["https://management.azure.com/.default"];
217+
let token = chained_credential.get_token(&scopes).await?;
218+
assert_eq!(
219+
token.token.secret(),
220+
"TEST TOKEN https://management.azure.com/.default"
221+
);
222+
let scopes = ["https://management.azure.com/.default"];
223+
let token = chained_credential.get_token(&scopes).await?;
224+
assert_eq!(
225+
token.token.secret(),
226+
"TEST TOKEN https://management.azure.com/.default"
227+
);
228+
229+
assert_eq!(token_failure.get_counter().await, 2);
230+
Ok(())
231+
}
232+
233+
#[tokio::test]
234+
async fn test_without_retry() -> azure_core::Result<()> {
235+
let token_failure = Arc::new(TokenFailure::new());
236+
let mut chained_credential = ChainedTokenCredential::new(None);
237+
chained_credential.add_source(token_failure.clone());
238+
chained_credential.add_source(Arc::new(MockCredential {}));
239+
240+
let scopes = ["https://management.azure.com/.default"];
241+
let token = chained_credential.get_token(&scopes).await?;
242+
assert_eq!(
243+
token.token.secret(),
244+
"TEST TOKEN https://management.azure.com/.default"
245+
);
246+
let scopes = ["https://management.azure.com/.default"];
247+
let token = chained_credential.get_token(&scopes).await?;
248+
assert_eq!(
249+
token.token.secret(),
250+
"TEST TOKEN https://management.azure.com/.default"
251+
);
252+
253+
assert_eq!(token_failure.get_counter().await, 1);
254+
Ok(())
255+
}
256+
}

sdk/identity/azure_identity/src/credentials/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
mod app_service_managed_identity_credential;
1212
#[cfg(not(target_arch = "wasm32"))]
1313
mod azure_cli_credentials;
14-
mod cache;
14+
pub(crate) mod cache;
1515
mod client_assertion_credentials;
1616
#[cfg(feature = "client_certificate")]
1717
mod client_certificate_credentials;

sdk/identity/azure_identity/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
mod authorization_code_flow;
77
mod azure_pipelines_credential;
8+
mod chained_token_credential;
89
mod credentials;
910
mod env;
1011
mod federated_credentials_flow;
@@ -14,6 +15,7 @@ mod timeout;
1415

1516
use azure_core::{error::ErrorKind, Error, Result};
1617
pub use azure_pipelines_credential::*;
18+
pub use chained_token_credential::*;
1719
pub use credentials::*;
1820
use std::borrow::Cow;
1921

0 commit comments

Comments
 (0)