Skip to content

Commit ee975d0

Browse files
authored
Add GetTokenOptions (Azure#2629)
1 parent c191d8e commit ee975d0

23 files changed

+251
-116
lines changed

sdk/core/azure_core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ sha2 = { workspace = true, optional = true }
2727
tokio = { workspace = true, optional = true }
2828
tracing.workspace = true
2929
typespec = { workspace = true, features = ["http", "json"] }
30-
typespec_client_core = { workspace = true, features = ["http", "json"] }
30+
typespec_client_core = { workspace = true, features = ["derive", "http", "json"] }
3131

3232
[build-dependencies]
3333
rustc_version.workspace = true

sdk/core/azure_core/src/credentials.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
use serde::{Deserialize, Serialize};
77
use std::{borrow::Cow, fmt::Debug};
8-
use typespec_client_core::date::OffsetDateTime;
8+
use typespec_client_core::{date::OffsetDateTime, fmt::SafeDebug};
99

1010
/// Default Azure authorization scope.
1111
pub static DEFAULT_SCOPE_SUFFIX: &str = "/.default";
@@ -87,10 +87,18 @@ impl AccessToken {
8787
}
8888
}
8989

90+
/// Options for getting a token from a [`TokenCredential`]
91+
#[derive(Clone, Default, SafeDebug)]
92+
pub struct TokenRequestOptions;
93+
9094
/// Represents a credential capable of providing an OAuth token.
9195
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
9296
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
9397
pub trait TokenCredential: Send + Sync + Debug {
94-
/// Gets a `AccessToken` for the specified resource
95-
async fn get_token(&self, scopes: &[&str]) -> crate::Result<AccessToken>;
98+
/// Gets an [`AccessToken`] for the specified scopes
99+
async fn get_token(
100+
&self,
101+
scopes: &[&str],
102+
options: Option<TokenRequestOptions>,
103+
) -> crate::Result<AccessToken>;
96104
}

sdk/core/azure_core/src/http/policies/bearer_token_policy.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl Policy for BearerTokenCredentialPolicy {
6767
drop(access_token);
6868
let mut access_token = self.access_token.write().await;
6969
if access_token.is_none() {
70-
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
70+
*access_token = Some(self.credential.get_token(&self.scopes(), None).await?);
7171
}
7272
}
7373
Some(token) if should_refresh(&token.expires_on) => {
@@ -79,7 +79,7 @@ impl Policy for BearerTokenCredentialPolicy {
7979
// access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
8080
if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on
8181
{
82-
match self.credential.get_token(&self.scopes()).await {
82+
match self.credential.get_token(&self.scopes(), None).await {
8383
Ok(new_token) => {
8484
*access_token = Some(new_token);
8585
}
@@ -121,7 +121,7 @@ fn should_refresh(expires_on: &OffsetDateTime) -> bool {
121121
mod tests {
122122
use super::*;
123123
use crate::{
124-
credentials::{Secret, TokenCredential},
124+
credentials::{Secret, TokenCredential, TokenRequestOptions},
125125
http::{
126126
headers::{Headers, AUTHORIZATION},
127127
policies::Policy,
@@ -172,7 +172,11 @@ mod tests {
172172
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
173173
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
174174
impl TokenCredential for MockCredential {
175-
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
175+
async fn get_token(
176+
&self,
177+
_: &[&str],
178+
_: Option<TokenRequestOptions>,
179+
) -> Result<AccessToken> {
176180
let i = self.calls.fetch_add(1, Ordering::SeqCst);
177181
self.tokens
178182
.get(i)

sdk/core/azure_core_test/src/credentials.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
//! Credentials for live and recorded tests.
55
use azure_core::{
6-
credentials::{AccessToken, Secret, TokenCredential},
6+
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
77
date::OffsetDateTime,
88
error::ErrorKind,
99
};
@@ -17,7 +17,11 @@ pub struct MockCredential;
1717
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
1818
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
1919
impl TokenCredential for MockCredential {
20-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
20+
async fn get_token(
21+
&self,
22+
scopes: &[&str],
23+
_: Option<TokenRequestOptions>,
24+
) -> azure_core::Result<AccessToken> {
2125
let token: Secret = format!("TEST TOKEN {}", scopes.join(" ")).into();
2226
let expires_on = OffsetDateTime::now_utc().saturating_add(
2327
Duration::from_secs(60 * 5).try_into().map_err(|err| {

sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async fn generate_authorization(
118118
let token = match auth_token {
119119
Credential::Token(token_credential) => {
120120
let token = token_credential
121-
.get_token(&[&scope_from_url(url)])
121+
.get_token(&[&scope_from_url(url)], None)
122122
.await?
123123
.token
124124
.secret()
@@ -146,7 +146,7 @@ mod tests {
146146
use std::sync::Arc;
147147

148148
use azure_core::{
149-
credentials::{AccessToken, TokenCredential},
149+
credentials::{AccessToken, TokenCredential, TokenRequestOptions},
150150
date,
151151
http::Method,
152152
};
@@ -168,7 +168,11 @@ mod tests {
168168
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
169169
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
170170
impl TokenCredential for TestTokenCredential {
171-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
171+
async fn get_token(
172+
&self,
173+
scopes: &[&str],
174+
_: Option<TokenRequestOptions>,
175+
) -> azure_core::Result<AccessToken> {
172176
let token = format!("{}+{}", self.0, scopes.join(","));
173177
Ok(AccessToken::new(
174178
token,

sdk/eventhubs/azure_messaging_eventhubs/src/common/connection_manager.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ impl ConnectionManager {
193193
debug!("Get Token.");
194194
let token = self
195195
.credential
196-
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE])
196+
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE], None)
197197
.await?;
198198

199199
debug!("Token for path {path} expires at {}", token.expires_on);
@@ -368,7 +368,7 @@ impl ConnectionManager {
368368

369369
let new_token = self
370370
.credential
371-
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE])
371+
.get_token(&[EVENTHUBS_AUTHORIZATION_SCOPE], None)
372372
.await?;
373373

374374
// Create an ephemeral session to host the authentication.
@@ -510,7 +510,7 @@ impl ConnectionManager {
510510
mod tests {
511511
use super::*;
512512
use async_trait::async_trait;
513-
use azure_core::{http::Url, Result};
513+
use azure_core::{credentials::TokenRequestOptions, http::Url, Result};
514514
use std::sync::Arc;
515515
use time::OffsetDateTime;
516516
use tracing::info;
@@ -551,7 +551,11 @@ mod tests {
551551

552552
#[async_trait]
553553
impl TokenCredential for MockTokenCredential {
554-
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
554+
async fn get_token(
555+
&self,
556+
_scopes: &[&str],
557+
_options: Option<TokenRequestOptions>,
558+
) -> Result<AccessToken> {
555559
// Simulate a token refresh by incrementing the token get count
556560
// and updating the token expiration time
557561
{

sdk/identity/azure_identity/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ impl ClientAssertion for VmClientAssertion {
9393
async fn secret(&self) -> azure_core::Result<String> {
9494
Ok(self
9595
.credential
96-
.get_token(&[&self.scope])
96+
.get_token(&[&self.scope], None)
9797
.await?
9898
.token
9999
.secret()
@@ -116,7 +116,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
116116
)?;
117117

118118
let fic_scope = String::from("your-service-app.com/scope");
119-
let fic_token = client_assertion_credential.get_token(&[&fic_scope]).await?;
119+
let fic_token = client_assertion_credential.get_token(&[&fic_scope], None).await?;
120120
Ok(())
121121
}
122122

sdk/identity/azure_identity/examples/azure_cli_credentials.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
1313

1414
let credentials = AzureCliCredential::new(None)?;
1515
let res = credentials
16-
.get_token(&["https://management.azure.com/.default"])
16+
.get_token(&["https://management.azure.com/.default"], None)
1717
.await?;
1818
eprintln!("Azure CLI response == {res:?}");
1919

sdk/identity/azure_identity/examples/default_credentials.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1717
let url = url::Url::parse(&format!("https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01"))?;
1818

1919
let access_token = credential
20-
.get_token(&["https://management.azure.com/.default"])
20+
.get_token(&["https://management.azure.com/.default"], None)
2121
.await?;
2222

2323
let response = reqwest::Client::new()

sdk/identity/azure_identity/examples/specific_credential.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Licensed under the MIT License.
33

44
use azure_core::{
5-
credentials::{AccessToken, TokenCredential},
5+
credentials::{AccessToken, TokenCredential, TokenRequestOptions},
66
error::{ErrorKind, ResultExt},
77
Error,
88
};
@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2525
let url = url::Url::parse(&format!("https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01"))?;
2626

2727
let access_token = credential
28-
.get_token(&["https://management.azure.com/.default"])
28+
.get_token(&["https://management.azure.com/.default"], None)
2929
.await?;
3030

3131
let response = reqwest::Client::new()
@@ -63,15 +63,21 @@ enum SpecificAzureCredentialKind {
6363
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
6464
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
6565
impl TokenCredential for SpecificAzureCredentialKind {
66-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
66+
async fn get_token(
67+
&self,
68+
scopes: &[&str],
69+
options: Option<TokenRequestOptions>,
70+
) -> azure_core::Result<AccessToken> {
6771
match self {
6872
#[cfg(not(target_arch = "wasm32"))]
69-
SpecificAzureCredentialKind::AzureCli(credential) => credential.get_token(scopes).await,
73+
SpecificAzureCredentialKind::AzureCli(credential) => {
74+
credential.get_token(scopes, options).await
75+
}
7076
SpecificAzureCredentialKind::ManagedIdentity(credential) => {
71-
credential.get_token(scopes).await
77+
credential.get_token(scopes, options).await
7278
}
7379
SpecificAzureCredentialKind::WorkloadIdentity(credential) => {
74-
credential.get_token(scopes).await
80+
credential.get_token(scopes, options).await
7581
}
7682
}
7783
}
@@ -133,7 +139,11 @@ impl SpecificAzureCredential {
133139
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
134140
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
135141
impl TokenCredential for SpecificAzureCredential {
136-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
137-
self.source.get_token(scopes).await
142+
async fn get_token(
143+
&self,
144+
scopes: &[&str],
145+
options: Option<TokenRequestOptions>,
146+
) -> azure_core::Result<AccessToken> {
147+
self.source.get_token(scopes, options).await
138148
}
139149
}

0 commit comments

Comments
 (0)