@@ -13,6 +13,7 @@ use async_lock::RwLock;
1313use async_trait:: async_trait;
1414use std:: sync:: Arc ;
1515use std:: time:: Duration ;
16+ use typespec_client_core:: date:: OffsetDateTime ;
1617use typespec_client_core:: http:: { Context , Request } ;
1718
1819/// Authentication policy for a bearer token.
@@ -23,9 +24,6 @@ pub struct BearerTokenCredentialPolicy {
2324 access_token : Arc < RwLock < Option < AccessToken > > > ,
2425}
2526
26- /// Default timeout in seconds before refreshing a new token.
27- const DEFAULT_REFRESH_TIME : Duration = Duration :: from_secs ( 120 ) ;
28-
2927impl BearerTokenCredentialPolicy {
3028 pub fn new < A , B > ( credential : Arc < dyn TokenCredential > , scopes : A ) -> Self
3129 where
@@ -63,16 +61,44 @@ impl Policy for BearerTokenCredentialPolicy {
6361 ) -> PolicyResult {
6462 let access_token = self . access_token . read ( ) . await ;
6563
66- if let Some ( token) = & ( * access_token) {
67- if token. is_expired ( Some ( DEFAULT_REFRESH_TIME ) ) {
64+ match access_token. as_ref ( ) {
65+ None => {
66+ // cache is empty. Upgrade the lock and acquire a token, provided another thread hasn't already done so
67+ drop ( access_token) ;
68+ let mut access_token = self . access_token . write ( ) . await ;
69+ if access_token. is_none ( ) {
70+ * access_token = Some ( self . credential . get_token ( & self . scopes ( ) ) . await ?) ;
71+ }
72+ }
73+ Some ( token) if should_refresh ( & token. expires_on ) => {
74+ // token is expired or within its refresh window. Upgrade the lock and
75+ // acquire a new token, provided another thread hasn't already done so
76+ let expires_on = token. expires_on ;
6877 drop ( access_token) ;
6978 let mut access_token = self . access_token . write ( ) . await ;
70- * access_token = Some ( self . credential . get_token ( & self . scopes ( ) ) . await ?) ;
79+ // access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
80+ if access_token. is_none ( ) || access_token. as_ref ( ) . unwrap ( ) . expires_on == expires_on
81+ {
82+ match self . credential . get_token ( & self . scopes ( ) ) . await {
83+ Ok ( new_token) => {
84+ * access_token = Some ( new_token) ;
85+ }
86+ Err ( e)
87+ if access_token. is_none ( )
88+ || expires_on <= OffsetDateTime :: now_utc ( ) =>
89+ {
90+ // propagate this error because we can't proceed without a new token
91+ return Err ( e) ;
92+ }
93+ Err ( _) => {
94+ // ignore this error because the cached token is still valid
95+ }
96+ }
97+ }
98+ }
99+ Some ( _) => {
100+ // do nothing; cached token is valid and not within its refresh window
71101 }
72- } else {
73- drop ( access_token) ;
74- let mut access_token = self . access_token . write ( ) . await ;
75- * access_token = Some ( self . credential . get_token ( & self . scopes ( ) ) . await ?) ;
76102 }
77103
78104 let access_token = self . access_token ( ) . await . ok_or_else ( || {
@@ -86,3 +112,161 @@ impl Policy for BearerTokenCredentialPolicy {
86112 next[ 0 ] . send ( ctx, request, & next[ 1 ..] ) . await
87113 }
88114}
115+
116+ fn should_refresh ( expires_on : & OffsetDateTime ) -> bool {
117+ * expires_on <= OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( 300 )
118+ }
119+
120+ #[ cfg( test) ]
121+ mod tests {
122+ use super :: * ;
123+ use crate :: {
124+ credentials:: { Secret , TokenCredential } ,
125+ http:: {
126+ headers:: { Headers , AUTHORIZATION } ,
127+ policies:: Policy ,
128+ Request , Response , StatusCode ,
129+ } ,
130+ Bytes , Result ,
131+ } ;
132+ use async_trait:: async_trait;
133+ use azure_core_test:: http:: MockHttpClient ;
134+ use futures:: FutureExt ;
135+ use std:: sync:: {
136+ atomic:: { AtomicUsize , Ordering } ,
137+ Arc ,
138+ } ;
139+ use std:: time:: Duration ;
140+ use time:: OffsetDateTime ;
141+ use typespec_client_core:: http:: { policies:: TransportPolicy , Method , TransportOptions } ;
142+
143+ #[ derive( Debug , Clone ) ]
144+ struct MockCredential {
145+ calls : Arc < AtomicUsize > ,
146+ tokens : Arc < [ AccessToken ] > ,
147+ }
148+
149+ impl MockCredential {
150+ fn new ( tokens : & [ AccessToken ] ) -> Self {
151+ Self {
152+ calls : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
153+ tokens : tokens. into ( ) ,
154+ }
155+ }
156+
157+ fn get_token_calls ( & self ) -> usize {
158+ self . calls . load ( Ordering :: SeqCst )
159+ }
160+ }
161+
162+ // ensure the number of get_token() calls matches the number of tokens
163+ // in a test case i.e., that the policy called get_token() as expected
164+ impl Drop for MockCredential {
165+ fn drop ( & mut self ) {
166+ if !self . tokens . is_empty ( ) {
167+ assert_eq ! ( self . tokens. len( ) , self . calls. load( Ordering :: SeqCst ) ) ;
168+ }
169+ }
170+ }
171+
172+ #[ cfg_attr( target_arch = "wasm32" , async_trait( ?Send ) ) ]
173+ #[ cfg_attr( not( target_arch = "wasm32" ) , async_trait) ]
174+ impl TokenCredential for MockCredential {
175+ async fn get_token ( & self , _scopes : & [ & str ] ) -> Result < AccessToken > {
176+ let i = self . calls . fetch_add ( 1 , Ordering :: SeqCst ) ;
177+ self . tokens
178+ . get ( i)
179+ . ok_or_else ( || Error :: message ( ErrorKind :: Credential , "no more mock tokens" ) )
180+ . cloned ( )
181+ }
182+ }
183+
184+ #[ tokio:: test]
185+ async fn authn_error ( ) {
186+ // this mock's get_token() will return an error because it has no tokens
187+ let credential = MockCredential :: new ( & [ ] ) ;
188+ let policy = BearerTokenCredentialPolicy :: new ( Arc :: new ( credential) , [ "scope" ] ) ;
189+ let client = MockHttpClient :: new ( |_| panic ! ( "expected an error from get_token" ) ) ;
190+ let transport = Arc :: new ( TransportPolicy :: new ( TransportOptions :: new ( Arc :: new (
191+ client,
192+ ) ) ) ) ;
193+ let mut req = Request :: new ( "https://localhost" . parse ( ) . unwrap ( ) , Method :: Get ) ;
194+
195+ let err = policy
196+ . send ( & Context :: default ( ) , & mut req, & [ transport. clone ( ) ] )
197+ . await
198+ . expect_err ( "request should fail" ) ;
199+
200+ assert_eq ! ( ErrorKind :: Credential , * err. kind( ) ) ;
201+ }
202+
203+ async fn run_test ( tokens : & [ AccessToken ] ) {
204+ let credential = Arc :: new ( MockCredential :: new ( tokens) ) ;
205+ let policy = BearerTokenCredentialPolicy :: new ( credential. clone ( ) , [ "scope" ] ) ;
206+ let client = Arc :: new ( MockHttpClient :: new ( move |actual| {
207+ let credential = credential. clone ( ) ;
208+ async move {
209+ let authz = actual. headers ( ) . get_str ( & AUTHORIZATION ) ?;
210+ // e.g. if this is the first request, we expect 1 get_token call and tokens[0] in the header
211+ let i = credential. get_token_calls ( ) . saturating_sub ( 1 ) ;
212+ let expected = & credential. tokens [ i] ;
213+
214+ assert_eq ! ( format!( "Bearer {}" , expected. token. secret( ) ) , authz) ;
215+
216+ Ok ( Response :: from_bytes (
217+ StatusCode :: Ok ,
218+ Headers :: new ( ) ,
219+ Bytes :: new ( ) ,
220+ ) )
221+ }
222+ . boxed ( )
223+ } ) ) ;
224+ let transport = Arc :: new ( TransportPolicy :: new ( TransportOptions :: new ( client) ) ) ;
225+
226+ let mut handles = vec ! [ ] ;
227+ for _ in 0 ..4 {
228+ let policy = policy. clone ( ) ;
229+ let transport = transport. clone ( ) ;
230+ let handle = tokio:: spawn ( async move {
231+ let ctx = Context :: default ( ) ;
232+ let mut req = Request :: new ( "https://localhost" . parse ( ) . unwrap ( ) , Method :: Get ) ;
233+ policy
234+ . send ( & ctx, & mut req, & [ transport. clone ( ) ] )
235+ . await
236+ . expect ( "successful request" ) ;
237+ } ) ;
238+ handles. push ( handle) ;
239+ }
240+
241+ for handle in handles {
242+ tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , handle)
243+ . await
244+ . expect ( "task timed out after 2 seconds" )
245+ . expect ( "completed task" ) ;
246+ }
247+ }
248+
249+ #[ tokio:: test]
250+ async fn caches_token ( ) {
251+ run_test ( & [ AccessToken {
252+ token : Secret :: new ( "fake" . to_string ( ) ) ,
253+ expires_on : OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( 3600 ) ,
254+ } ] )
255+ . await ;
256+ }
257+
258+ #[ tokio:: test]
259+ async fn refreshes_token ( ) {
260+ run_test ( & [
261+ AccessToken {
262+ token : Secret :: new ( "1" . to_string ( ) ) ,
263+ expires_on : OffsetDateTime :: now_utc ( ) - Duration :: from_secs ( 1 ) ,
264+ } ,
265+ AccessToken {
266+ token : Secret :: new ( "2" . to_string ( ) ) ,
267+ expires_on : OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( 3600 ) ,
268+ } ,
269+ ] )
270+ . await ;
271+ }
272+ }
0 commit comments