Skip to content

Commit 6950115

Browse files
authored
Plumb endpoint and access token (#19)
* Basic plumbing of the endpoint and access token through to calls to xetcas
1 parent 148b720 commit 6950115

File tree

12 files changed

+109
-62
lines changed

12 files changed

+109
-62
lines changed

cas_client/src/remote_client.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use anyhow::anyhow;
44
use bytes::Buf;
55
use cas::key::Key;
66
use cas_types::{QueryChunkResponse, QueryReconstructionResponse, UploadXorbResponse};
7-
use reqwest::{StatusCode, Url};
7+
use reqwest::{header::{HeaderMap, HeaderValue}, StatusCode, Url};
88
use serde::{de::DeserializeOwned, Serialize};
99

1010
use bytes::Bytes;
@@ -82,9 +82,9 @@ impl Client for RemoteClient {
8282
}
8383

8484
impl RemoteClient {
85-
pub async fn from_config(endpoint: String) -> Self {
86-
Self {
87-
client: CASAPIClient::new(&endpoint),
85+
pub async fn from_config(endpoint: String, token: Option<String>) -> Self {
86+
Self {
87+
client: CASAPIClient::new(&endpoint, token)
8888
}
8989
}
9090
}
@@ -93,20 +93,22 @@ impl RemoteClient {
9393
pub struct CASAPIClient {
9494
client: reqwest::Client,
9595
endpoint: String,
96+
token: Option<String>,
9697
}
9798

9899
impl Default for CASAPIClient {
99100
fn default() -> Self {
100-
Self::new(CAS_ENDPOINT)
101+
Self::new(CAS_ENDPOINT, None)
101102
}
102103
}
103104

104105
impl CASAPIClient {
105-
pub fn new(endpoint: &str) -> Self {
106+
pub fn new(endpoint: &str, token: Option<String>) -> Self {
106107
let client = reqwest::Client::builder().build().unwrap();
107108
Self {
108109
client,
109110
endpoint: endpoint.to_string(),
111+
token
110112
}
111113
}
112114

@@ -222,12 +224,17 @@ impl CASAPIClient {
222224
/// Reconstruct the file
223225
async fn reconstruct_file(&self, file_id: &MerkleHash) -> Result<QueryReconstructionResponse> {
224226
let url = Url::parse(&format!(
225-
"{}/reconstruction/{}",
226-
self.endpoint,
227+
"{}/reconstruction/{}",
228+
self.endpoint,
227229
file_id.hex()
228230
))?;
229231

230-
let response = self.client.get(url).send().await?;
232+
let mut headers = HeaderMap::new();
233+
if let Some(tok) = &self.token {
234+
headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", tok)).unwrap());
235+
}
236+
237+
let response = self.client.get(url).headers(headers).send().await?;
231238
let response_body = response.bytes().await?;
232239
let response_parsed: QueryReconstructionResponse =
233240
serde_json::from_reader(response_body.reader())?;
@@ -290,7 +297,7 @@ mod tests {
290297
#[tokio::test]
291298
async fn test_basic_put() {
292299
// Arrange
293-
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string()).await;
300+
let rc = RemoteClient::from_config(CAS_ENDPOINT.to_string(), None).await;
294301
let prefix = PREFIX_DEFAULT;
295302
let (hash, data, chunk_boundaries) = gen_dummy_xorb(3, 10248, true);
296303

cas_types/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ pub struct CASReconstructionTerm {
2222
pub hash: HexMerkleHash,
2323
pub unpacked_length: u32,
2424
pub range: Range,
25-
pub range_start_offset: u32,
25+
// TODO: disabling until https://github.com/huggingface-internal/xetcas/pull/31/files
26+
// is merged.
27+
// pub range_start_offset: u32,
2628
pub url: String,
2729
}
2830

data/src/bin/example.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ fn default_clean_config() -> Result<TranslatorConfig> {
7474
cas_storage_config: StorageConfig {
7575
endpoint: Endpoint::FileSystem(path.join("xorbs")),
7676
auth: Auth {
77-
user_id: "".into(),
78-
login_id: "".into(),
77+
token: None,
7978
},
8079
prefix: "default".into(),
8180
cache_config: Some(CacheConfig {
@@ -88,8 +87,7 @@ fn default_clean_config() -> Result<TranslatorConfig> {
8887
shard_storage_config: StorageConfig {
8988
endpoint: Endpoint::FileSystem(path.join("xorbs")),
9089
auth: Auth {
91-
user_id: "".into(),
92-
login_id: "".into(),
90+
token: None,
9391
},
9492
prefix: "default-merkledb".into(),
9593
cache_config: Some(CacheConfig {
@@ -123,8 +121,7 @@ fn default_smudge_config() -> Result<TranslatorConfig> {
123121
cas_storage_config: StorageConfig {
124122
endpoint: Endpoint::FileSystem(path.join("xorbs")),
125123
auth: Auth {
126-
user_id: "".into(),
127-
login_id: "".into(),
124+
token: None,
128125
},
129126
prefix: "default".into(),
130127
cache_config: Some(CacheConfig {
@@ -137,8 +134,7 @@ fn default_smudge_config() -> Result<TranslatorConfig> {
137134
shard_storage_config: StorageConfig {
138135
endpoint: Endpoint::FileSystem(path.join("xorbs")),
139136
auth: Auth {
140-
user_id: "".into(),
141-
login_id: "".into(),
137+
token: None,
142138
},
143139
prefix: "default-merkledb".into(),
144140
cache_config: Some(CacheConfig {
@@ -234,7 +230,7 @@ async fn smudge(mut reader: impl Read, mut writer: impl Write) -> Result<()> {
234230
let translator = PointerFileTranslator::new(default_smudge_config()?).await?;
235231

236232
translator
237-
.smudge_file_from_pointer(&pointer_file, &mut writer, None)
233+
.smudge_file_from_pointer(&pointer_file, &mut writer, None, None, None)
238234
.await?;
239235

240236
Ok(())

data/src/cas_interface.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ pub(crate) async fn create_cas_client(
3636
unreachable!();
3737
};
3838

39-
// Auth info.
40-
let _user_id = &cas_storage_config.auth.user_id;
41-
let _auth = &cas_storage_config.auth.login_id;
4239

4340
// Usage tracking.
4441
let _repo_paths = maybe_repo_info
@@ -49,7 +46,7 @@ pub(crate) async fn create_cas_client(
4946

5047
// Raw remote client.
5148
let remote_client = Arc::new(
52-
RemoteClient::from_config(endpoint.to_string()).await,
49+
RemoteClient::from_config(endpoint.to_string(), cas_storage_config.auth.token.clone()).await,
5350
);
5451

5552
// Try add in caching capability.

data/src/configurations.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ pub enum Endpoint {
1111

1212
#[derive(Debug)]
1313
pub struct Auth {
14-
pub user_id: String,
15-
pub login_id: String,
14+
pub token: Option<String>,
1615
}
1716

1817
#[derive(Debug)]

data/src/data_processing.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,10 @@ impl PointerFileTranslator {
313313
pointer: &PointerFile,
314314
writer: &mut impl std::io::Write,
315315
range: Option<(usize, usize)>,
316+
endpoint: Option<String>,
317+
token: Option<String>,
316318
) -> Result<()> {
317-
self.smudge_file_from_hash(&pointer.hash()?, writer, range)
319+
self.smudge_file_from_hash(&pointer.hash()?, writer, range, endpoint, token)
318320
.await
319321
}
320322

@@ -323,13 +325,21 @@ impl PointerFileTranslator {
323325
file_id: &MerkleHash,
324326
writer: &mut impl std::io::Write,
325327
_range: Option<(usize, usize)>,
328+
endpoint: Option<String>,
329+
token: Option<String>,
326330
) -> Result<()> {
327331
let endpoint = match &self.config.cas_storage_config.endpoint {
328-
Endpoint::Server(endpoint) => endpoint.clone(),
332+
Endpoint::Server(config_endpoint) => {
333+
if let Some(endpoint) = endpoint {
334+
endpoint
335+
} else {
336+
config_endpoint.clone()
337+
}
338+
},
329339
Endpoint::FileSystem(_) => panic!("aaaaaaaa no server"),
330340
};
331341

332-
let rc = CASAPIClient::new(&endpoint);
342+
let rc = CASAPIClient::new(&endpoint, token);
333343

334344
rc.write_file(file_id, writer).await?;
335345

data/src/shard_interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub async fn create_shard_client(
3939
) -> Result<Arc<dyn ShardClientInterface>> {
4040
info!("Shard endpoint = {:?}", shard_storage_config.endpoint);
4141
let client: Arc<dyn ShardClientInterface> = match &shard_storage_config.endpoint {
42-
Server(endpoint) => Arc::new(HttpShardClient::new(endpoint)),
42+
Server(endpoint) => Arc::new(HttpShardClient::new(endpoint, shard_storage_config.auth.token.clone())),
4343
FileSystem(path) => Arc::new(LocalShardClient::new(path).await?),
4444
};
4545

hf_xet/Cargo.lock

Lines changed: 24 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hf_xet/src/config.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use data::{DEFAULT_BLOCK_SIZE, errors};
55

66
pub const SMALL_FILE_THRESHOLD: usize = 1;
77

8-
pub fn default_config(endpoint: String) -> errors::Result<TranslatorConfig> {
8+
pub fn default_config(endpoint: String, token: Option<String>) -> errors::Result<TranslatorConfig> {
99
let path = current_dir()?.join(".xet");
1010
fs::create_dir_all(&path)?;
1111

@@ -14,8 +14,7 @@ pub fn default_config(endpoint: String) -> errors::Result<TranslatorConfig> {
1414
cas_storage_config: StorageConfig {
1515
endpoint: Endpoint::Server(endpoint.clone()),
1616
auth: Auth {
17-
user_id: "".into(),
18-
login_id: "".into(),
17+
token: token.clone(),
1918
},
2019
prefix: "default".into(),
2120
cache_config: Some(CacheConfig {
@@ -28,8 +27,7 @@ pub fn default_config(endpoint: String) -> errors::Result<TranslatorConfig> {
2827
shard_storage_config: StorageConfig {
2928
endpoint: Endpoint::Server(endpoint),
3029
auth: Auth {
31-
user_id: "".into(),
32-
login_id: "".into(),
30+
token: token,
3331
},
3432
prefix: "default-merkledb".into(),
3533
cache_config: Some(CacheConfig {

hf_xet/src/data_client.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@ use crate::config::default_config;
1212
pub const MAX_CONCURRENT_UPLOADS: usize = 8; // TODO
1313
pub const MAX_CONCURRENT_DOWNLOADS: usize = 8; // TODO
1414

15-
const DEFAULT_CAS_ENDPOINT: &str = "https://cas-server.us.dev.moon.huggingface.tech";
15+
const DEFAULT_CAS_ENDPOINT: &str = "http://localhost:8080";
1616
const READ_BLOCK_SIZE: usize = 1024 * 1024;
1717

18-
pub async fn upload_async(file_paths: Vec<String>) -> errors::Result<Vec<PointerFile>> {
18+
pub async fn upload_async(file_paths: Vec<String>, endpoint: Option<String>, token: Option<String>) -> errors::Result<Vec<PointerFile>> {
1919
// chunk files
2020
// produce Xorbs + Shards
2121
// upload shards and xorbs
2222
// for each file, return the filehash
23+
let endpoint = endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.to_string());
2324

24-
let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
25+
let config = default_config(endpoint, token)?;
2526
let processor = Arc::new(PointerFileTranslator::new(config).await?);
2627
let processor = &processor;
2728
// for all files, clean them, producing pointer files.
@@ -45,16 +46,20 @@ pub async fn upload_async(file_paths: Vec<String>) -> errors::Result<Vec<Pointer
4546
Ok(pointers)
4647
}
4748

48-
pub async fn download_async(pointer_files: Vec<PointerFile>) -> errors::Result<Vec<String>> {
49-
let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
49+
pub async fn download_async(pointer_files: Vec<PointerFile>, endpoint: Option<String>, token: Option<String>) -> errors::Result<Vec<String>> {
50+
let config = default_config(endpoint.clone().unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), token.clone())?;
5051
let processor = Arc::new(PointerFileTranslator::new(config).await?);
5152
let processor = &processor;
5253
let paths = tokio_par_for_each(
5354
pointer_files,
5455
MAX_CONCURRENT_DOWNLOADS,
55-
|pointer_file, _| async move {
56-
let proc = processor.clone();
57-
smudge_file(&proc, &pointer_file).await
56+
|pointer_file, _| {
57+
let tok = token.clone();
58+
let end = endpoint.clone();
59+
async move {
60+
let proc = processor.clone();
61+
smudge_file(&proc, &pointer_file, end.clone(), tok.clone()).await
62+
}
5863
},
5964
).await.map_err(|e| match e {
6065
ParallelError::JoinError => {
@@ -87,13 +92,13 @@ async fn clean_file(processor: &PointerFileTranslator, f: String) -> errors::Res
8792
Ok(pf)
8893
}
8994

90-
async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result<String> {
95+
async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile, endpoint: Option<String>, token: Option<String>) -> errors::Result<String> {
9196
let path = PathBuf::from(pointer_file.path());
9297
if let Some(parent_dir) = path.parent() {
9398
fs::create_dir_all(parent_dir)?;
9499
}
95100
let mut f = File::create(&path)?;
96-
proc.smudge_file_from_pointer(&pointer_file, &mut f, None).await?;
101+
proc.smudge_file_from_pointer(&pointer_file, &mut f, None, endpoint, token).await?;
97102
Ok(pointer_file.path().to_string())
98103
}
99104

@@ -122,7 +127,7 @@ mod tests {
122127
let pointers = vec![
123128
PointerFile::init_from_info("/tmp/foo.rs", "6999733a46030e67f6f020651c91442ace735572458573df599106e54646867c", 4203),
124129
];
125-
let paths = download_async(pointers).await.unwrap();
130+
let paths = download_async(pointers, "http://localhost:8080", "12345").await.unwrap();
126131
println!("paths: {paths:?}");
127132
}
128133
}

0 commit comments

Comments
 (0)