Skip to content

Commit 71c6b62

Browse files
jgodlewbpronan
andauthored
Add token refresh to CAS/Shard Client (#30)
Co-authored-by: Brian Ronan <[email protected]>
1 parent d087b27 commit 71c6b62

28 files changed

+769
-245
lines changed

Cargo.lock

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cas_client/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ rustls-pemfile = "2.0.0"
5252
hyper-rustls = { version = "0.26.0", features = ["http2"] }
5353
lz4 = "1.24.0"
5454
reqwest = "0.12.7"
55+
reqwest-middleware = "0.3.3"
5556
serde = { version = "1.0.210", features = ["derive"] }
5657
cas_types = { version = "0.1.0", path = "../cas_types" }
5758
url = "2.5.2"

cas_client/src/auth.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use anyhow::anyhow;
2+
use cas::auth::{AuthConfig, TokenProvider};
3+
use reqwest::header::HeaderValue;
4+
use reqwest::header::AUTHORIZATION;
5+
use reqwest::{Request, Response};
6+
use reqwest_middleware::{Middleware, Next};
7+
use std::sync::{Arc, Mutex};
8+
9+
/// AuthMiddleware is a thread-safe middleware that adds a CAS auth token to outbound requests.
10+
/// If the token it holds is expired, it will automatically be refreshed.
11+
pub struct AuthMiddleware {
12+
token_provider: Arc<Mutex<TokenProvider>>,
13+
}
14+
15+
impl AuthMiddleware {
16+
/// Fetches a token from our TokenProvider. This locks the TokenProvider as we might need
17+
/// to refresh the token if it has expired.
18+
///
19+
/// In the common case, this lock is held only to read the underlying token stored
20+
/// in memory. However, in the event of an expired token (e.g. once every 15 min),
21+
/// we will need to hold the lock while making a call to refresh the token
22+
/// (e.g. to a remote service). During this time, no other CAS requests can proceed
23+
/// from this client until the token has been fetched. This is expected/ok since we
24+
/// don't have a valid token and thus any calls would fail.
25+
fn get_token(&self) -> Result<String, anyhow::Error> {
26+
let mut provider = self
27+
.token_provider
28+
.lock()
29+
.map_err(|e| anyhow!("lock error: {e:?}"))?;
30+
provider
31+
.get_valid_token()
32+
.map_err(|e| anyhow!("couldn't get token: {e:?}"))
33+
}
34+
}
35+
36+
impl From<&AuthConfig> for AuthMiddleware {
37+
fn from(cfg: &AuthConfig) -> Self {
38+
Self {
39+
token_provider: Arc::new(Mutex::new(TokenProvider::new(cfg))),
40+
}
41+
}
42+
}
43+
44+
#[async_trait::async_trait]
45+
impl Middleware for AuthMiddleware {
46+
async fn handle(
47+
&self,
48+
mut req: Request,
49+
extensions: &mut hyper::http::Extensions,
50+
next: Next<'_>,
51+
) -> reqwest_middleware::Result<Response> {
52+
let token = self
53+
.get_token()
54+
.map_err(reqwest_middleware::Error::Middleware)?;
55+
56+
let headers = req.headers_mut();
57+
headers.insert(
58+
AUTHORIZATION,
59+
HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(),
60+
);
61+
next.run(req, extensions).await
62+
}
63+
}

cas_client/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ pub enum CasClientError {
6969
#[error("Parse Error: {0}")]
7070
ParseError(#[from] url::ParseError),
7171

72+
#[error("ReqwestMiddleware Error: {0}")]
73+
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
7274
#[error("Reqwest Error: {0}")]
7375
ReqwestError(#[from] reqwest::Error),
7476

cas_client/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
#![allow(dead_code)]
33

44
pub use crate::error::CasClientError;
5+
pub use auth::AuthMiddleware;
56
pub use caching_client::{CachingClient, DEFAULT_BLOCK_SIZE};
67
pub use interface::Client;
78
pub use local_client::LocalClient;
89
pub use merklehash::MerkleHash; // re-export since this is required for the client API.
910
pub use passthrough_staging_client::PassthroughStagingClient;
11+
pub use remote_client::build_reqwest_client;
1012
pub use remote_client::CASAPIClient;
1113
pub use remote_client::RemoteClient;
1214
pub use staging_client::{new_staging_client, new_staging_client_with_progressbar, StagingClient};
1315
pub use staging_trait::{Staging, StagingBypassable};
1416

17+
mod auth;
1518
mod caching_client;
1619
mod cas_connection_pool;
1720
mod client_adapter;

cas_client/src/remote_client.rs

Lines changed: 52 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,23 @@ use std::io::{Cursor, Write};
22

33
use anyhow::anyhow;
44
use bytes::Buf;
5-
use cas::key::Key;
6-
use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse};
7-
use reqwest::{
8-
header::{HeaderMap, HeaderValue},
9-
StatusCode, Url,
10-
};
5+
use bytes::Bytes;
6+
use reqwest::{StatusCode, Url};
7+
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware};
118
use serde::{de::DeserializeOwned, Serialize};
9+
use tracing::{debug, warn};
1210

13-
use bytes::Bytes;
11+
use cas::auth::AuthConfig;
12+
use cas::key::Key;
1413
use cas_object::CasObject;
1514
use cas_types::CASReconstructionTerm;
16-
use tracing::{debug, warn};
17-
18-
use crate::{error::Result, CasClientError};
19-
15+
use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse};
16+
use error_printer::OptionPrinter;
2017
use merklehash::MerkleHash;
2118

2219
use crate::Client;
20+
use crate::{error::Result, AuthMiddleware, CasClientError};
21+
2322
pub const CAS_ENDPOINT: &str = "http://localhost:8080";
2423
pub const PREFIX_DEFAULT: &str = "default";
2524

@@ -84,44 +83,37 @@ impl Client for RemoteClient {
8483
}
8584

8685
impl RemoteClient {
87-
pub async fn from_config(endpoint: String, token: Option<String>) -> Self {
86+
pub async fn from_config(endpoint: String, auth_config: &Option<AuthConfig>) -> Self {
8887
Self {
89-
client: CASAPIClient::new(&endpoint, token),
88+
client: CASAPIClient::new(&endpoint, auth_config),
9089
}
9190
}
9291
}
9392

9493
#[derive(Debug)]
9594
pub struct CASAPIClient {
96-
client: reqwest::Client,
95+
client: ClientWithMiddleware,
9796
endpoint: String,
98-
token: Option<String>,
9997
}
10098

10199
impl Default for CASAPIClient {
102100
fn default() -> Self {
103-
Self::new(CAS_ENDPOINT, None)
101+
Self::new(CAS_ENDPOINT, &None)
104102
}
105103
}
106104

107105
impl CASAPIClient {
108-
pub fn new(endpoint: &str, token: Option<String>) -> Self {
109-
let client = reqwest::Client::builder().build().unwrap();
106+
pub fn new(endpoint: &str, auth_config: &Option<AuthConfig>) -> Self {
107+
let client = build_reqwest_client(auth_config).unwrap();
110108
Self {
111109
client,
112110
endpoint: endpoint.to_string(),
113-
token,
114111
}
115112
}
116113

117114
pub async fn exists(&self, key: &Key) -> Result<bool> {
118115
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
119-
let response = self
120-
.client
121-
.head(url)
122-
.headers(self.request_headers())
123-
.send()
124-
.await?;
116+
let response = self.client.head(url).send().await?;
125117
match response.status() {
126118
StatusCode::OK => Ok(true),
127119
StatusCode::NOT_FOUND => Ok(false),
@@ -133,12 +125,7 @@ impl CASAPIClient {
133125

134126
pub async fn get_length(&self, key: &Key) -> Result<Option<u64>> {
135127
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
136-
let response = self
137-
.client
138-
.head(url)
139-
.headers(self.request_headers())
140-
.send()
141-
.await?;
128+
let response = self.client.head(url).send().await?;
142129
let status = response.status();
143130
if status == StatusCode::NOT_FOUND {
144131
return Ok(None);
@@ -189,13 +176,7 @@ impl CASAPIClient {
189176
writer.set_position(0);
190177
let data = writer.into_inner();
191178

192-
let response = self
193-
.client
194-
.post(url)
195-
.headers(self.request_headers())
196-
.body(data)
197-
.send()
198-
.await?;
179+
let response = self.client.post(url).body(data).send().await?;
199180
let response_body = response.bytes().await?;
200181
let response_parsed: UploadXorbResponse = serde_json::from_reader(response_body.reader())?;
201182

@@ -247,12 +228,7 @@ impl CASAPIClient {
247228
file_id.hex()
248229
))?;
249230

250-
let response = self
251-
.client
252-
.get(url)
253-
.headers(self.request_headers())
254-
.send()
255-
.await?;
231+
let response = self.client.get(url).send().await?;
256232
let response_body = response.bytes().await?;
257233
let response_parsed: QueryReconstructionResponse =
258234
serde_json::from_reader(response_body.reader())?;
@@ -262,29 +238,13 @@ impl CASAPIClient {
262238

263239
pub async fn shard_query_chunk(&self, key: &Key) -> Result<QueryChunkResponse> {
264240
let url = Url::parse(&format!("{}/chunk/{key}", self.endpoint))?;
265-
let response = self
266-
.client
267-
.get(url)
268-
.headers(self.request_headers())
269-
.send()
270-
.await?;
241+
let response = self.client.get(url).send().await?;
271242
let response_body = response.bytes().await?;
272243
let response_parsed: QueryChunkResponse = serde_json::from_reader(response_body.reader())?;
273244

274245
Ok(response_parsed)
275246
}
276247

277-
fn request_headers(&self) -> HeaderMap {
278-
let mut headers = HeaderMap::new();
279-
if let Some(tok) = &self.token {
280-
headers.insert(
281-
"Authorization",
282-
HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap(),
283-
);
284-
}
285-
headers
286-
}
287-
288248
async fn post_json<ReqT, RespT>(&self, url: Url, request_body: &ReqT) -> Result<RespT>
289249
where
290250
ReqT: Serialize,
@@ -330,22 +290,49 @@ async fn get_one(term: &CASReconstructionTerm) -> Result<Bytes> {
330290
Ok(Bytes::from(sliced))
331291
}
332292

293+
/// builds the client to talk to CAS.
294+
pub fn build_reqwest_client(
295+
auth_config: &Option<AuthConfig>,
296+
) -> std::result::Result<ClientWithMiddleware, reqwest::Error> {
297+
let auth_middleware = auth_config
298+
.as_ref()
299+
.map(AuthMiddleware::from)
300+
.info_none("CAS auth disabled");
301+
let reqwest_client = reqwest::Client::builder().build()?;
302+
Ok(ClientBuilder::new(reqwest_client)
303+
.maybe_with(auth_middleware)
304+
.build())
305+
}
306+
307+
/// Helper trait to allow the reqwest_middleware client to optionally add a middleware.
308+
trait OptionalMiddleware {
309+
fn maybe_with<M: Middleware>(self, middleware: Option<M>) -> Self;
310+
}
311+
312+
impl OptionalMiddleware for ClientBuilder {
313+
fn maybe_with<M: Middleware>(self, middleware: Option<M>) -> Self {
314+
match middleware {
315+
Some(m) => self.with(m),
316+
None => self,
317+
}
318+
}
319+
}
320+
333321
#[cfg(test)]
334322
mod tests {
335-
336-
use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB};
337-
use merklehash::DataHash;
338323
use rand::Rng;
339324
use tracing_test::traced_test;
340325

341326
use super::*;
327+
use merkledb::{prelude::MerkleDBHighLevelMethodsV1, Chunk, MerkleMemDB};
328+
use merklehash::DataHash;
342329

343330
#[ignore]
344331
#[traced_test]
345332
#[tokio::test]
346333
async fn test_basic_put() {
347334
// Arrange
348-
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), None).await;
335+
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), &None).await;
349336
let prefix = PREFIX_DEFAULT;
350337
let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true);
351338

0 commit comments

Comments
 (0)