@@ -33,7 +33,8 @@ use serde::Serialize;
3333
3434use super :: PubKey ;
3535
36- const JWK_REFRESH_INTERVAL : u64 = 15 ;
36+ const JWKS_REFRESH_TIMEOUT : u64 = 10 ;
37+ const JWKS_REFRESH_INTERVAL : u64 = 600 ;
3738
3839#[ derive( Debug , Serialize , Deserialize ) ]
3940pub struct JwkKey {
@@ -99,17 +100,17 @@ pub struct JwkKeyStore {
99100 cached_keys : Arc < RwLock < HashMap < String , PubKey > > > ,
100101 pub ( crate ) last_refreshed_at : RwLock < Option < Instant > > ,
101102 pub ( crate ) refresh_interval : Duration ,
103+ pub ( crate ) refresh_timeout : Duration ,
102104 pub ( crate ) load_keys_func : Option < Arc < dyn Fn ( ) -> HashMap < String , PubKey > + Send + Sync > > ,
103105}
104106
105107impl JwkKeyStore {
106108 pub fn new ( url : String ) -> Self {
107- let refresh_interval = Duration :: from_secs ( JWK_REFRESH_INTERVAL * 60 ) ;
108- let keys = Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ;
109109 Self {
110110 url,
111- cached_keys : keys,
112- refresh_interval,
111+ cached_keys : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
112+ refresh_interval : Duration :: from_secs ( JWKS_REFRESH_INTERVAL ) ,
113+ refresh_timeout : Duration :: from_secs ( JWKS_REFRESH_TIMEOUT ) ,
113114 last_refreshed_at : RwLock :: new ( None ) ,
114115 load_keys_func : None ,
115116 }
@@ -124,6 +125,16 @@ impl JwkKeyStore {
124125 self
125126 }
126127
128+ pub fn with_refresh_interval ( mut self , interval : u64 ) -> Self {
129+ self . refresh_interval = Duration :: from_secs ( interval) ;
130+ self
131+ }
132+
133+ pub fn with_refresh_timeout ( mut self , timeout : u64 ) -> Self {
134+ self . refresh_timeout = Duration :: from_secs ( timeout) ;
135+ self
136+ }
137+
127138 pub fn url ( & self ) -> String {
128139 self . url . clone ( )
129140 }
@@ -136,12 +147,19 @@ impl JwkKeyStore {
136147 return Ok ( load_keys_func ( ) ) ;
137148 }
138149
139- let response = reqwest:: get ( & self . url ) . await . map_err ( |e| {
150+ let client = reqwest:: Client :: builder ( )
151+ . timeout ( self . refresh_timeout )
152+ . build ( )
153+ . map_err ( |e| {
154+ ErrorCode :: InvalidConfig ( format ! ( "Failed to create jwks client: {}" , e) )
155+ } ) ?;
156+ let response = client. get ( & self . url ) . send ( ) . await . map_err ( |e| {
140157 ErrorCode :: AuthenticateFailure ( format ! ( "Could not download JWKS: {}" , e) )
141158 } ) ?;
142- let body = response. text ( ) . await . unwrap ( ) ;
143- let jwk_keys = serde_json:: from_str :: < JwkKeys > ( & body)
144- . map_err ( |e| ErrorCode :: InvalidConfig ( format ! ( "Failed to parse keys: {}" , e) ) ) ?;
159+ let jwk_keys: JwkKeys = response
160+ . json ( )
161+ . await
162+ . map_err ( |e| ErrorCode :: InvalidConfig ( format ! ( "Failed to parse JWKS: {}" , e) ) ) ?;
145163 let mut new_keys: HashMap < String , PubKey > = HashMap :: new ( ) ;
146164 for k in & jwk_keys. keys {
147165 new_keys. insert ( k. kid . to_string ( ) , k. get_public_key ( ) ?) ;
@@ -166,6 +184,7 @@ impl JwkKeyStore {
166184 let new_keys = match self . load_keys ( ) . await {
167185 Ok ( new_keys) => new_keys,
168186 Err ( err) => {
187+ warn ! ( "Failed to load JWKS: {}" , err) ;
169188 if !old_keys. is_empty ( ) {
170189 return Ok ( old_keys) ;
171190 }
@@ -177,9 +196,9 @@ impl JwkKeyStore {
177196 if !new_keys. keys ( ) . eq ( old_keys. keys ( ) ) {
178197 info ! ( "JWKS keys changed." ) ;
179198 }
180- * self . cached_keys . write ( ) = new_keys;
199+ * self . cached_keys . write ( ) = new_keys. clone ( ) ;
181200 self . last_refreshed_at . write ( ) . replace ( Instant :: now ( ) ) ;
182- Ok ( old_keys )
201+ Ok ( new_keys )
183202 }
184203
185204 #[ async_backtrace:: framed]
@@ -200,31 +219,12 @@ impl JwkKeyStore {
200219 }
201220 } ;
202221
203- // happy path: the key_id is found in the store
204- if let Some ( key) = keys. get ( & key_id) {
205- return Ok ( key. clone ( ) ) ;
222+ match keys. get ( & key_id) {
223+ None => Err ( ErrorCode :: AuthenticateFailure ( format ! (
224+ "key id {} not found in jwk store" ,
225+ key_id
226+ ) ) ) ,
227+ Some ( key) => Ok ( key. clone ( ) ) ,
206228 }
207-
208- // if the key_id is not set here, it might because the JWKS has been rotated, we need to refresh it.
209- warn ! (
210- "key_id {} not found in jwks store, try to reload keys" ,
211- key_id
212- ) ;
213- let keys = self
214- . load_keys_with_cache ( true )
215- . await
216- . map_err ( |e| e. add_message ( "failed to reload JWKS keys" ) ) ?;
217-
218- let key = match keys. get ( & key_id) {
219- None => {
220- return Err ( ErrorCode :: AuthenticateFailure ( format ! (
221- "key id {} not found in jwk store" ,
222- key_id
223- ) ) ) ;
224- }
225- Some ( key) => key. clone ( ) ,
226- } ;
227-
228- Ok ( key)
229229 }
230230}
0 commit comments