Skip to content

Commit c2294d8

Browse files
authored
feat(query): add config jwks_refresh_interval & jwks_refresh_timeout (#17087)
* feat(query): add config jwks_refresh_interval & jwks_refresh_timeout * fix: remove force reload when key not found * z * z * z * z * z * z * z * z * z
1 parent 85c4caf commit c2294d8

File tree

8 files changed

+77
-73
lines changed

8 files changed

+77
-73
lines changed

.github/actions/setup_build_tool/action.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ runs:
3636
EOF
3737
3838
RUNNER_PROVIDER="${RUNNER_PROVIDER:-github}"
39+
export SCCACHE_IDLE_TIMEOUT=0
3940
case ${RUNNER_PROVIDER} in
4041
aws)
4142
echo "setting up sccache for AWS S3..."

src/query/config/src/config.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1528,7 +1528,14 @@ pub struct QueryConfig {
15281528
#[clap(long, value_name = "VALUE", default_value_t)]
15291529
pub jwt_key_file: String,
15301530

1531-
/// If there are multiple trusted jwt provider put it into additional_jwt_key_files configuration
1531+
/// Interval in seconds to refresh jwks
1532+
#[clap(long, value_name = "VALUE", default_value = "600")]
1533+
pub jwks_refresh_interval: u64,
1534+
1535+
/// Timeout in seconds to refresh jwks
1536+
#[clap(long, value_name = "VALUE", default_value = "10")]
1537+
pub jwks_refresh_timeout: u64,
1538+
15321539
#[clap(skip)]
15331540
pub jwt_key_files: Vec<String>,
15341541

@@ -1754,6 +1761,8 @@ impl TryInto<InnerQueryConfig> for QueryConfig {
17541761
max_storage_io_requests: self.max_storage_io_requests,
17551762
jwt_key_file: self.jwt_key_file,
17561763
jwt_key_files: self.jwt_key_files,
1764+
jwks_refresh_interval: self.jwks_refresh_interval,
1765+
jwks_refresh_timeout: self.jwks_refresh_timeout,
17571766
default_storage_format: self.default_storage_format,
17581767
default_compression: self.default_compression,
17591768
builtin: BuiltInConfig {
@@ -1845,6 +1854,8 @@ impl From<InnerQueryConfig> for QueryConfig {
18451854
max_storage_io_requests: inner.max_storage_io_requests,
18461855
jwt_key_file: inner.jwt_key_file,
18471856
jwt_key_files: inner.jwt_key_files,
1857+
jwks_refresh_interval: inner.jwks_refresh_interval,
1858+
jwks_refresh_timeout: inner.jwks_refresh_timeout,
18481859
default_storage_format: inner.default_storage_format,
18491860
default_compression: inner.default_compression,
18501861
users: inner.builtin.users,

src/query/config/src/inner.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ pub struct QueryConfig {
215215

216216
pub jwt_key_file: String,
217217
pub jwt_key_files: Vec<String>,
218+
pub jwks_refresh_interval: u64,
219+
pub jwks_refresh_timeout: u64,
218220
pub default_storage_format: String,
219221
pub default_compression: String,
220222
pub builtin: BuiltInConfig,
@@ -301,6 +303,8 @@ impl Default for QueryConfig {
301303
max_storage_io_requests: None,
302304
jwt_key_file: "".to_string(),
303305
jwt_key_files: Vec::new(),
306+
jwks_refresh_interval: 600,
307+
jwks_refresh_timeout: 10,
304308
default_storage_format: "auto".to_string(),
305309
default_compression: "auto".to_string(),
306310
builtin: BuiltInConfig::default(),

src/query/service/src/auth.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ impl AuthMgr {
8585
jwt_auth: JwtAuthenticator::create(
8686
cfg.query.jwt_key_file.clone(),
8787
cfg.query.jwt_key_files.clone(),
88+
cfg.query.jwks_refresh_interval,
89+
cfg.query.jwks_refresh_timeout,
8890
),
8991
})
9092
}

src/query/service/tests/it/storages/testdata/configs_table_basic.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo
100100
| 'query' | 'http_handler_tls_server_root_ca_cert' | '' | '' |
101101
| 'query' | 'internal_enable_sandbox_tenant' | 'false' | '' |
102102
| 'query' | 'internal_merge_on_read_mutation' | 'false' | '' |
103+
| 'query' | 'jwks_refresh_interval' | '600' | '' |
104+
| 'query' | 'jwks_refresh_timeout' | '10' | '' |
103105
| 'query' | 'jwt_key_file' | '' | '' |
104106
| 'query' | 'jwt_key_files' | '' | '' |
105107
| 'query' | 'management_mode' | 'false' | '' |

src/query/users/src/jwt/authenticator.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,30 @@ impl CustomClaims {
7878
}
7979

8080
impl JwtAuthenticator {
81-
pub fn create(jwt_key_file: String, jwt_key_files: Vec<String>) -> Option<Self> {
81+
pub fn create(
82+
jwt_key_file: String,
83+
jwt_key_files: Vec<String>,
84+
jwks_refresh_interval: u64,
85+
jwks_refresh_timeout: u64,
86+
) -> Option<Self> {
8287
if jwt_key_file.is_empty() && jwt_key_files.is_empty() {
8388
return None;
8489
}
8590
// init a vec of key store
86-
let mut key_stores = vec![jwk::JwkKeyStore::new(jwt_key_file)];
91+
let mut key_stores = vec![];
92+
if !jwt_key_file.is_empty() {
93+
key_stores.push(
94+
jwk::JwkKeyStore::new(jwt_key_file)
95+
.with_refresh_interval(jwks_refresh_interval)
96+
.with_refresh_timeout(jwks_refresh_timeout),
97+
);
98+
}
8799
for u in jwt_key_files {
88-
key_stores.push(jwk::JwkKeyStore::new(u))
100+
key_stores.push(
101+
jwk::JwkKeyStore::new(u)
102+
.with_refresh_interval(jwks_refresh_interval)
103+
.with_refresh_timeout(jwks_refresh_timeout),
104+
);
89105
}
90106
Some(JwtAuthenticator { key_stores })
91107
}

src/query/users/src/jwt/jwk.rs

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ use serde::Serialize;
3333

3434
use super::PubKey;
3535

36-
const JWK_REFRESH_INTERVAL: u64 = 15;
36+
const JWKS_REFRESH_TIMEOUT: u64 = 10;
37+
const JWKS_REFRESH_INTERVAL: u64 = 600;
3738

3839
#[derive(Debug, Serialize, Deserialize)]
3940
pub struct JwkKey {
@@ -99,17 +100,17 @@ pub struct JwkKeyStore {
99100
cached_keys: Arc<RwLock<HashMap<String, PubKey>>>,
100101
pub(crate) last_refreshed_at: RwLock<Option<Instant>>,
101102
pub(crate) refresh_interval: Duration,
103+
pub(crate) refresh_timeout: Duration,
102104
pub(crate) load_keys_func: Option<Arc<dyn Fn() -> HashMap<String, PubKey> + Send + Sync>>,
103105
}
104106

105107
impl JwkKeyStore {
106108
pub fn new(url: String) -> Self {
107-
let refresh_interval = Duration::from_secs(JWK_REFRESH_INTERVAL * 60);
108-
let keys = Arc::new(RwLock::new(HashMap::new()));
109109
Self {
110110
url,
111-
cached_keys: keys,
112-
refresh_interval,
111+
cached_keys: Arc::new(RwLock::new(HashMap::new())),
112+
refresh_interval: Duration::from_secs(JWKS_REFRESH_INTERVAL),
113+
refresh_timeout: Duration::from_secs(JWKS_REFRESH_TIMEOUT),
113114
last_refreshed_at: RwLock::new(None),
114115
load_keys_func: None,
115116
}
@@ -124,6 +125,16 @@ impl JwkKeyStore {
124125
self
125126
}
126127

128+
pub fn with_refresh_interval(mut self, interval: u64) -> Self {
129+
self.refresh_interval = Duration::from_secs(interval);
130+
self
131+
}
132+
133+
pub fn with_refresh_timeout(mut self, timeout: u64) -> Self {
134+
self.refresh_timeout = Duration::from_secs(timeout);
135+
self
136+
}
137+
127138
pub fn url(&self) -> String {
128139
self.url.clone()
129140
}
@@ -136,12 +147,19 @@ impl JwkKeyStore {
136147
return Ok(load_keys_func());
137148
}
138149

139-
let response = reqwest::get(&self.url).await.map_err(|e| {
150+
let client = reqwest::Client::builder()
151+
.timeout(self.refresh_timeout)
152+
.build()
153+
.map_err(|e| {
154+
ErrorCode::InvalidConfig(format!("Failed to create jwks client: {}", e))
155+
})?;
156+
let response = client.get(&self.url).send().await.map_err(|e| {
140157
ErrorCode::AuthenticateFailure(format!("Could not download JWKS: {}", e))
141158
})?;
142-
let body = response.text().await.unwrap();
143-
let jwk_keys = serde_json::from_str::<JwkKeys>(&body)
144-
.map_err(|e| ErrorCode::InvalidConfig(format!("Failed to parse keys: {}", e)))?;
159+
let jwk_keys: JwkKeys = response
160+
.json()
161+
.await
162+
.map_err(|e| ErrorCode::InvalidConfig(format!("Failed to parse JWKS: {}", e)))?;
145163
let mut new_keys: HashMap<String, PubKey> = HashMap::new();
146164
for k in &jwk_keys.keys {
147165
new_keys.insert(k.kid.to_string(), k.get_public_key()?);
@@ -166,6 +184,7 @@ impl JwkKeyStore {
166184
let new_keys = match self.load_keys().await {
167185
Ok(new_keys) => new_keys,
168186
Err(err) => {
187+
warn!("Failed to load JWKS: {}", err);
169188
if !old_keys.is_empty() {
170189
return Ok(old_keys);
171190
}
@@ -177,9 +196,9 @@ impl JwkKeyStore {
177196
if !new_keys.keys().eq(old_keys.keys()) {
178197
info!("JWKS keys changed.");
179198
}
180-
*self.cached_keys.write() = new_keys;
199+
*self.cached_keys.write() = new_keys.clone();
181200
self.last_refreshed_at.write().replace(Instant::now());
182-
Ok(old_keys)
201+
Ok(new_keys)
183202
}
184203

185204
#[async_backtrace::framed]
@@ -200,31 +219,12 @@ impl JwkKeyStore {
200219
}
201220
};
202221

203-
// happy path: the key_id is found in the store
204-
if let Some(key) = keys.get(&key_id) {
205-
return Ok(key.clone());
222+
match keys.get(&key_id) {
223+
None => Err(ErrorCode::AuthenticateFailure(format!(
224+
"key id {} not found in jwk store",
225+
key_id
226+
))),
227+
Some(key) => Ok(key.clone()),
206228
}
207-
208-
// if the key_id is not set here, it might because the JWKS has been rotated, we need to refresh it.
209-
warn!(
210-
"key_id {} not found in jwks store, try to reload keys",
211-
key_id
212-
);
213-
let keys = self
214-
.load_keys_with_cache(true)
215-
.await
216-
.map_err(|e| e.add_message("failed to reload JWKS keys"))?;
217-
218-
let key = match keys.get(&key_id) {
219-
None => {
220-
return Err(ErrorCode::AuthenticateFailure(format!(
221-
"key id {} not found in jwk store",
222-
key_id
223-
)));
224-
}
225-
Some(key) => key.clone(),
226-
};
227-
228-
Ok(key)
229229
}
230230
}

src/query/users/tests/it/jwt/authenticator.rs

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,11 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use std::collections::HashMap;
16-
use std::sync::atomic::AtomicUsize;
17-
use std::sync::atomic::Ordering;
18-
use std::sync::Arc;
19-
2015
use base64::engine::general_purpose;
2116
use base64::prelude::*;
2217
use databend_common_base::base::tokio;
2318
use databend_common_exception::Result;
24-
use databend_common_users::JwkKeyStore;
2519
use databend_common_users::JwtAuthenticator;
26-
use databend_common_users::PubKey;
2720
use jwt_simple::prelude::*;
2821
use wiremock::matchers::method;
2922
use wiremock::matchers::path;
@@ -60,7 +53,7 @@ async fn test_parse_non_custom_claim() -> Result<()> {
6053
.mount(&server)
6154
.await;
6255
let first_url = format!("http://{}{}", server.address(), json_path);
63-
let auth = JwtAuthenticator::create(first_url, vec![]).unwrap();
56+
let auth = JwtAuthenticator::create(first_url, vec![], 600, 10).unwrap();
6457
let user_name = "test-user2";
6558
let my_additional_data = MyAdditionalData {
6659
user_is_admin: false,
@@ -74,28 +67,3 @@ async fn test_parse_non_custom_claim() -> Result<()> {
7467
assert_eq!(res.custom.role, None);
7568
Ok(())
7669
}
77-
78-
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
79-
async fn test_jwk_key_store_retry_on_key_not_found() -> Result<()> {
80-
let func_calls = Arc::new(AtomicUsize::new(0));
81-
let func_calls_cloned = func_calls.clone();
82-
83-
let mock_load_keys = Arc::new(move || -> HashMap<String, PubKey> {
84-
let mut keys_map = HashMap::new();
85-
keys_map.insert(
86-
"key1".to_string(),
87-
PubKey::RSA256(RS256KeyPair::generate(2048).unwrap().public_key().into()),
88-
);
89-
func_calls_cloned.fetch_add(1, Ordering::SeqCst);
90-
keys_map
91-
});
92-
let store = JwkKeyStore::new("".to_string()).with_load_keys_func(mock_load_keys);
93-
94-
let r = store.get_key(Some("key2".to_string())).await;
95-
assert_eq!(
96-
r.unwrap_err().message(),
97-
"key id key2 not found in jwk store"
98-
);
99-
assert_eq!(func_calls.load(Ordering::SeqCst), 2);
100-
Ok(())
101-
}

0 commit comments

Comments
 (0)