Skip to content

Commit ffa9faa

Browse files
authored
Adds User-Agent header in cas_client requests to CAS (#546)
Adds User-Agent when making requests to CAS. * sets to (project) / (version) * version is picked from Cargo.toml * project is hf-xet crates, git-xet crates, (also hard-coded xtool) The reason for this change is to add better observability on the server - so we can segment reqeusts by client and understand client versions in the wild.
1 parent aac19ee commit ffa9faa

File tree

16 files changed

+117
-62
lines changed

16 files changed

+117
-62
lines changed

cas_client/src/download_utils.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ mod tests {
635635
MerkleHash::default(),
636636
file_range,
637637
server.base_url(),
638-
Arc::new(build_http_client(RetryConfig::default(), "")?),
638+
Arc::new(build_http_client(RetryConfig::default(), "", "")?),
639639
);
640640

641641
fetch_info.query().await?;
@@ -682,7 +682,7 @@ mod tests {
682682
MerkleHash::default(),
683683
file_range_to_refresh,
684684
server.base_url(),
685-
Arc::new(build_http_client(RetryConfig::default(), "")?),
685+
Arc::new(build_http_client(RetryConfig::default(), "", "")?),
686686
));
687687

688688
// Spawn multiple tasks each calling into refresh with a different delay in
@@ -751,7 +751,7 @@ mod tests {
751751
MerkleHash::default(),
752752
file_range,
753753
server.base_url(),
754-
Arc::new(build_http_client(RetryConfig::default(), "")?),
754+
Arc::new(build_http_client(RetryConfig::default(), "", "")?),
755755
);
756756

757757
let (offset_info_first_range, terms) = fetch_info.query().await?.unwrap();
@@ -762,7 +762,7 @@ mod tests {
762762
range: x1range[0].range,
763763
fetch_info: Arc::new(fetch_info),
764764
chunk_cache: None,
765-
client: Arc::new(build_http_client(RetryConfig::default(), "")?),
765+
client: Arc::new(build_http_client(RetryConfig::default(), "", "")?),
766766
range_download_single_flight: Arc::new(Group::new()),
767767
},
768768
term: terms[0].clone(),

cas_client/src/http_client.rs

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -88,29 +88,40 @@ impl RetryConfig<No429RetryStrategy> {
8888
}
8989
}
9090

91-
fn reqwest_client() -> Result<reqwest::Client, CasClientError> {
91+
fn reqwest_client(user_agent: &str) -> Result<reqwest::Client, CasClientError> {
9292
// custom dns resolver not supported in WASM. no access to getaddrinfo/any other dns interface.
9393
#[cfg(target_family = "wasm")]
9494
{
95-
static CLIENT: std::sync::LazyLock<reqwest::Client> = std::sync::LazyLock::new(|| reqwest::Client::new());
96-
Ok((&*CLIENT).clone())
95+
// For WASM, create a new client with the specified user_agent
96+
// Note: we could cache this, but user_agent can vary, so we create per-call
97+
let mut builder = reqwest::Client::builder();
98+
if !user_agent.is_empty() {
99+
builder = builder.user_agent(user_agent);
100+
}
101+
Ok(builder.build()?)
97102
}
98103

99104
#[cfg(not(target_family = "wasm"))]
100105
{
101106
use xet_runtime::XetRuntime;
102107

103108
let client = XetRuntime::get_or_create_reqwest_client(|| {
104-
reqwest::Client::builder()
109+
let mut builder = reqwest::Client::builder()
105110
.pool_idle_timeout(*CLIENT_IDLE_CONNECTION_TIMEOUT)
106111
.pool_max_idle_per_host(*CLIENT_MAX_IDLE_CONNECTIONS)
107-
.http1_only() // high throughput parallel I/O has been shown to bottleneck with http2
108-
.build()
112+
.http1_only(); // high throughput parallel I/O has been shown to bottleneck with http2
113+
114+
if !user_agent.is_empty() {
115+
builder = builder.user_agent(user_agent);
116+
}
117+
118+
builder.build()
109119
})?;
110120

111121
info!(
112122
idle_timeout=?*CLIENT_IDLE_CONNECTION_TIMEOUT,
113123
max_idle_connections=*CLIENT_MAX_IDLE_CONNECTIONS,
124+
user_agent=?if user_agent.is_empty() { None } else { Some(user_agent) },
114125
"HTTP client configured"
115126
);
116127

@@ -124,12 +135,13 @@ pub fn build_auth_http_client<R: RetryableStrategy + Send + Sync + 'static>(
124135
auth_config: &Option<AuthConfig>,
125136
retry_config: RetryConfig<R>,
126137
session_id: &str,
138+
user_agent: &str,
127139
) -> Result<ClientWithMiddleware, CasClientError> {
128140
let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from);
129141
let logging_middleware = Some(LoggingMiddleware);
130142
let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
131143

132-
let client = ClientBuilder::new(reqwest_client()?)
144+
let client = ClientBuilder::new(reqwest_client(user_agent)?)
133145
.maybe_with(auth_middleware)
134146
.with(get_retry_middleware(retry_config))
135147
.maybe_with(logging_middleware)
@@ -142,11 +154,12 @@ pub fn build_auth_http_client<R: RetryableStrategy + Send + Sync + 'static>(
142154
pub fn build_auth_http_client_no_retry(
143155
auth_config: &Option<AuthConfig>,
144156
session_id: &str,
157+
user_agent: &str,
145158
) -> Result<ClientWithMiddleware, CasClientError> {
146159
let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from).info_none("CAS auth disabled");
147160
let logging_middleware = Some(LoggingMiddleware);
148161
let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
149-
Ok(ClientBuilder::new(reqwest_client()?)
162+
Ok(ClientBuilder::new(reqwest_client(user_agent)?)
150163
.maybe_with(auth_middleware)
151164
.maybe_with(logging_middleware)
152165
.maybe_with(session_middleware)
@@ -158,14 +171,15 @@ pub fn build_auth_http_client_no_retry(
158171
pub fn build_http_client<R: RetryableStrategy + Send + Sync + 'static>(
159172
retry_config: RetryConfig<R>,
160173
session_id: &str,
174+
user_agent: &str,
161175
) -> Result<ClientWithMiddleware, CasClientError> {
162-
build_auth_http_client(&None, retry_config, session_id)
176+
build_auth_http_client(&None, retry_config, session_id, user_agent)
163177
}
164178

165179
/// Builds HTTP Client to talk to CAS.
166180
/// Includes retry middleware with exponential backoff.
167-
pub fn build_http_client_no_retry(session_id: &str) -> Result<ClientWithMiddleware, CasClientError> {
168-
build_auth_http_client_no_retry(&None, session_id)
181+
pub fn build_http_client_no_retry(session_id: &str, user_agent: &str) -> Result<ClientWithMiddleware, CasClientError> {
182+
build_auth_http_client_no_retry(&None, session_id, user_agent)
169183
}
170184

171185
/// RetryStrategy
@@ -386,7 +400,7 @@ mod tests {
386400
max_retry_interval_ms: 3000,
387401
strategy: DefaultRetryableStrategy,
388402
};
389-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
403+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
390404

391405
// Act & Assert - should retry and log
392406
let response = client.get(server.url("/data")).send().await.unwrap();
@@ -412,7 +426,7 @@ mod tests {
412426
max_retry_interval_ms: 3000,
413427
strategy: No429RetryStrategy,
414428
};
415-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
429+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
416430

417431
// Act & Assert - should retry and log
418432
let response = client.get(server.url("/data")).send().await.unwrap();
@@ -442,7 +456,7 @@ mod tests {
442456
max_retry_interval_ms: 3000,
443457
strategy: DefaultRetryableStrategy,
444458
};
445-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
459+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
446460

447461
// Act & Assert - should retry and log
448462
let response = client.get(server.url("/data")).send().await.unwrap();
@@ -468,7 +482,7 @@ mod tests {
468482
max_retry_interval_ms: 3000,
469483
strategy: No429RetryStrategy,
470484
};
471-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
485+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
472486

473487
// Act & Assert - should retry and log
474488
let response = client.get(server.url("/data")).send().await.unwrap();
@@ -499,7 +513,7 @@ mod tests {
499513
max_retry_interval_ms: 6000,
500514
strategy: DefaultRetryableStrategy,
501515
};
502-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
516+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
503517

504518
// Act & Assert - should retry and log
505519
let response = client.get(server.url("/data")).send().await.unwrap();
@@ -527,7 +541,7 @@ mod tests {
527541
max_retry_interval_ms: 6000,
528542
strategy: No429RetryStrategy,
529543
};
530-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
544+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
531545

532546
// Act & Assert - should retry and log
533547
let response = client.get(server.url("/data")).send().await.unwrap();
@@ -557,7 +571,7 @@ mod tests {
557571
max_retry_interval_ms: 6000,
558572
strategy: No429RetryStrategy,
559573
};
560-
let client = build_auth_http_client(&None, retry_config, "").unwrap();
574+
let client = build_auth_http_client(&None, retry_config, "", "").unwrap();
561575

562576
// Act & Assert - should retry and log
563577
let response = client.get(server.url("/data")).send().await.unwrap();

cas_client/src/remote_client.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ impl RemoteClient {
209209
shard_cache_directory: Option<PathBuf>,
210210
session_id: &str,
211211
dry_run: bool,
212+
user_agent: &str,
212213
) -> Self {
213214
// use disk cache if cache_config provided.
214215
let chunk_cache = if let Some(cache_config) = cache_config {
@@ -229,13 +230,13 @@ impl RemoteClient {
229230
endpoint: endpoint.to_string(),
230231
dry_run,
231232
authenticated_http_client_with_retry: Arc::new(
232-
http_client::build_auth_http_client(auth, RetryConfig::default(), session_id).unwrap(),
233+
http_client::build_auth_http_client(auth, RetryConfig::default(), session_id, user_agent).unwrap(),
233234
),
234235
authenticated_http_client: Arc::new(
235-
http_client::build_auth_http_client_no_retry(auth, session_id).unwrap(),
236+
http_client::build_auth_http_client_no_retry(auth, session_id, user_agent).unwrap(),
236237
),
237238
http_client_with_retry: Arc::new(
238-
http_client::build_http_client(RetryConfig::default(), session_id).unwrap(),
239+
http_client::build_http_client(RetryConfig::default(), session_id, user_agent).unwrap(),
239240
),
240241
chunk_cache,
241242
#[cfg(not(target_family = "wasm"))]
@@ -941,7 +942,7 @@ mod tests {
941942
let raw_xorb = build_raw_xorb(3, ChunkSize::Random(512, 10248));
942943

943944
let threadpool = XetRuntime::new().unwrap();
944-
let client = RemoteClient::new(CAS_ENDPOINT, &None, &None, None, "", false);
945+
let client = RemoteClient::new(CAS_ENDPOINT, &None, &None, None, "", false, "");
945946

946947
let cas_object = build_and_verify_cas_object(raw_xorb, Some(CompressionScheme::LZ4));
947948

@@ -1315,7 +1316,7 @@ mod tests {
13151316

13161317
// test reconstruct and sequential write
13171318
let test = test_case.clone();
1318-
let client = RemoteClient::new(endpoint, &None, &None, None, "", false);
1319+
let client = RemoteClient::new(endpoint, &None, &None, None, "", false, "");
13191320
let buf = ThreadSafeBuffer::default();
13201321
let provider = SequentialOutput::from(buf.clone());
13211322
let resp = threadpool.external_run_async_task(async move {
@@ -1337,7 +1338,7 @@ mod tests {
13371338

13381339
// test reconstruct and parallel write
13391340
let test = test_case;
1340-
let client = RemoteClient::new(endpoint, &None, &None, None, "", false);
1341+
let client = RemoteClient::new(endpoint, &None, &None, None, "", false, "");
13411342
let buf = ThreadSafeBuffer::default();
13421343
let provider = SeekingOutputProvider::from(buf.clone());
13431344
let resp = threadpool.external_run_async_task(async move {

data/src/bin/xtool.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use walkdir::WalkDir;
1818
use xet_runtime::XetRuntime;
1919

2020
const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
21+
const USER_AGENT: &str = concat!("xtool", "/", env!("CARGO_PKG_VERSION"));
2122

2223
#[derive(Parser)]
2324
struct XCommand {
@@ -60,7 +61,7 @@ impl XCommand {
6061
&endpoint,
6162
RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?,
6263
Some("main".to_owned()),
63-
"xtool",
64+
USER_AGENT,
6465
"",
6566
cred_helper,
6667
)?;
@@ -209,6 +210,7 @@ async fn query_reconstruction(
209210
None,
210211
Some((jwt_info.access_token, jwt_info.exp)),
211212
Some(token_refresher),
213+
USER_AGENT.to_string(),
212214
)?;
213215
let cas_storage_config = &config.data_config;
214216
let remote_client = RemoteClient::new(
@@ -218,6 +220,7 @@ async fn query_reconstruction(
218220
Some(config.shard_config.cache_directory.clone()),
219221
"",
220222
true,
223+
&cas_storage_config.user_agent,
221224
);
222225

223226
remote_client

data/src/configurations.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct DataConfig {
2323
pub prefix: String,
2424
pub cache_config: CacheConfig,
2525
pub staging_directory: Option<PathBuf>,
26+
pub user_agent: String,
2627
}
2728

2829
#[derive(Debug)]
@@ -98,6 +99,7 @@ impl TranslatorConfig {
9899
cache_size: *CHUNK_CACHE_SIZE_BYTES,
99100
},
100101
staging_directory: None,
102+
user_agent: String::new(),
101103
},
102104
shard_config: ShardConfig {
103105
prefix: PREFIX_DEFAULT.into(),

0 commit comments

Comments
 (0)