35 changes: 33 additions & 2 deletions cas_client/src/download_utils.rs
@@ -172,18 +172,40 @@ pub(crate) struct FetchTermDownload {
pub range_download_single_flight: RangeDownloadSingleFlight,
}

#[derive(Debug)]
#[derive(Derivative)]
#[derivative(Debug)]
pub(crate) struct SequentialTermDownload {
pub term: CASReconstructionTerm,
pub download: FetchTermDownload,
pub skip_bytes: u64, // number of bytes to skip at the front
pub take: u64, /* number of bytes to take after skipping bytes,
* effectively taking [skip_bytes..skip_bytes+take]
* out of the downloaded range */
#[derivative(Debug = "ignore")]
pub coalesced_range_reuse_cache: Arc<dyn ChunkCache>,
Contributor:
I wish there were a way to use one cache passed in here; there's already one in FetchTermDownload.

Collaborator (Author):
This is definitely a trade-off. Using a disk cache brings down memory usage, but as Adrien and Corentin mentioned, a disk cache slows down their environment a LOT. I don't think there is a universal solution for all environments, and this approach is tailored to their usage; by the same token, I don't think an in-memory cache is suitable for all hf-xet users either.

}

impl SequentialTermDownload {
pub async fn run(self) -> Result<TermDownloadResult<Vec<u8>>> {
// First try from the coalesced_range_reuse_cache
let key = Key {
prefix: PREFIX_DEFAULT.into(),
hash: self.download.hash,
};

if let Ok(Some(cache_item)) = self.coalesced_range_reuse_cache.get(&key, &self.term.range).await {
// extract just the actual range data out of the term download output
let start = self.skip_bytes as usize;
let end = start + self.take as usize;
let final_term_data = &cache_item.data[start..end];

return Ok(TermDownloadResult {
payload: final_term_data.to_vec(),
duration: Duration::from_secs(0),
n_retries_on_403: 0,
});
}

let TermDownloadResult {
payload:
TermDownloadOutput {
@@ -196,8 +218,15 @@ impl SequentialTermDownload {
} = self.download.run().await?;

// if the requested range is smaller than the fetched range, trim it down to the right data
// the requested range cannot be larger than the fetched range.
// the requested range cannot be larger than the fetched range, and cache the fetched range
// because it will be reused within the segment.
// "else" case data matches exact, save some work, return whole data.
if self.term.range.start != chunk_range.start || self.term.range.end != chunk_range.end {
Contributor:
This condition captures some of the cases but not all; it helps with the case we're dealing with, but it isn't general for everything.

Collaborator:
True, but I think it's a pretty solid heuristic here -- the only case it doesn't capture is when the downloaded range is referenced in full in this situation and then partially referenced later. But I think for now, best-effort is probably fine.

self.coalesced_range_reuse_cache
.put(&key, &chunk_range, &chunk_byte_indices, &data)
.await?;
}

let start_idx = (self.term.range.start - chunk_range.start) as usize;
let end_idx = (self.term.range.end - chunk_range.start) as usize;

@@ -574,6 +603,7 @@ async fn download_fetch_term_data(
mod tests {
use anyhow::Result;
use cas_types::{HttpRange, QueryReconstructionResponse};
use chunk_cache::MemoryCache;
use http::header::RANGE;
use httpmock::prelude::*;
use tokio::task::JoinSet;
@@ -768,6 +798,7 @@ mod tests {
term: terms[0].clone(),
skip_bytes: offset_info_first_range,
take: file_range.length(),
coalesced_range_reuse_cache: Arc::new(MemoryCache::default()),
};

let handle = tokio::spawn(async move { download_task.run().await });
7 changes: 7 additions & 0 deletions cas_client/src/remote_client.rs
@@ -413,6 +413,8 @@ impl RemoteClient {
},
DownloadQueueItem::Metadata(fetch_info) => {
// query for the file info of the first segment

use chunk_cache::MemoryCache;
let segment_size = download_scheduler_clone.next_segment_size()?;
debug!(call_id, segment_size, "querying file info");
let (segment, maybe_remainder) = fetch_info.take_segment(segment_size);
@@ -427,6 +429,10 @@ impl RemoteClient {
// define the term download tasks
let mut remaining_segment_len = segment_size;
debug!(call_id, num_tasks = terms.len(), "enqueueing download tasks");

// in-memory cache for this segment
let coalesced_range_reuse_cache = Arc::new(MemoryCache::default());
Collaborator:
Like @hoytak, I'm worried this could suck up too much RAM since it can keep growing.

I think capping the overall byte size of the cache would be good - maybe the default is 1GB, configurable by env variable (a rough sketch of that follows below).

Instead of random eviction, do we want to evict based on a min heap of num_bytes_in_segment? Meaning, let's evict the smallest segments, since that would minimize the number of network requests.

Also, I think we should probably log a bit more so we know how many times a segment is reused. That is likely a nice-to-have.
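
For concreteness, a minimal sketch of the env-configurable cap; the variable name HF_XET_MEM_CACHE_BYTES and the 1GB default are placeholders for illustration, neither exists in this PR:

use std::sync::OnceLock;

// Hypothetical knob; name and default are assumptions, not part of this PR.
const CACHE_BYTES_ENV: &str = "HF_XET_MEM_CACHE_BYTES";
const DEFAULT_CACHE_BYTES: u64 = 1 << 30; // the 1GB default suggested above

fn cache_byte_limit() -> u64 {
    static LIMIT: OnceLock<u64> = OnceLock::new();
    *LIMIT.get_or_init(|| {
        std::env::var(CACHE_BYTES_ENV)
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(DEFAULT_CACHE_BYTES)
    })
}

// MemoryCache::put would then evict before inserting, e.g.:
//     while state.total_bytes + data_len > cache_byte_limit() {
//         evict_one(&mut state); // eviction policy debated below
//     }

Reading the limit once through OnceLock keeps put() from hitting the environment on every insert.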

Collaborator (@hoytak, Nov 6, 2025):
I don't think evicting based on num_bytes_in_segment would work: the cache would fill up with large chunks, and any small chunks in it would get evicted immediately, before they could be used.

A better heuristic here is perhaps insertion time. However, I think the memory issues could also be resolved by having a common cache with a larger limit instead of the per-download cache, or simply by changing that default to be smaller, reflecting the limit on the number of parallel downloads.

Collaborator:
Isn't the main optimization goal of adding a memory cache to limit the number of network requests needed to reconstruct a file? By minimizing those, we reduce overall download time and ensure we don't download more bytes than we need (by repeatedly downloading the same segment).

So with an upper bound on cache size we need to evict entries, and the entries we evict should be the least-used and smallest segments, since evicting those costs the fewest extra network requests.

Since we get all the reconstruction info at the beginning of the download, we can do the bookkeeping to see which segments have the most terms, and maybe weight those segments by overall byte count when inserting into the cache. Then, as we consume from the cache, we decrement the remaining terms for a segment and remove it from the cache once the segment is fully consumed. When evicting, we evict the segments with the least usage and the fewest bytes, preserving the biggest segments with the most remaining terms.

Yes, this will result in the cache filling up with large chunks, but those large chunks also have lots of remaining terms to be consumed, so they're worth keeping. And the large chunks will be removed from the cache once all their terms are consumed anyway.

Collaborator:
The map would be something like:

    xorb_id -> (num_terms_remaining, weighted_segment_score, bytes)
    weighted_segment_score = (segment_bytes / HF_XET_NUM_RANGE_IN_SEGMENT_BASE) / num_terms

That is, segment_bytes as a fraction of the configurable segment size, divided by the number of terms the segment satisfies; the smallest of these weighted scores should be evicted first.

This way, on a cache hit, once num_terms_remaining hits 0 the entire cache entry is removed. (A sketch of this bookkeeping follows.)

for (i, term) in terms.into_iter().enumerate() {
let skip_bytes = if i == 0 { offset_into_first_range } else { 0 };
let take = remaining_total_len
@@ -446,6 +452,7 @@ impl RemoteClient {
term,
skip_bytes,
take,
coalesced_range_reuse_cache: coalesced_range_reuse_cache.clone(),
};

remaining_total_len -= take;
2 changes: 2 additions & 0 deletions chunk_cache/src/lib.rs
@@ -1,6 +1,7 @@
mod cache_manager;
mod disk;
pub mod error;
mod memory;

use std::path::PathBuf;

@@ -10,6 +11,7 @@ use cas_types::{ChunkRange, Key};
pub use disk::DiskCache;
pub use disk::test_utils::*;
use error::ChunkCacheError;
pub use memory::MemoryCache;
use mockall::automock;

pub use crate::disk::DEFAULT_CHUNK_CACHE_CAPACITY;
251 changes: 251 additions & 0 deletions chunk_cache/src/memory.rs
Contributor:
No evictions, so this could get large; I think it's fair to assume that for fragmented files we will end up with all of the ranges in here.

@@ -0,0 +1,251 @@
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use cas_types::{ChunkRange, Key};
use tokio::sync::RwLock;

use crate::error::ChunkCacheError;
use crate::{CacheRange, ChunkCache};

#[derive(Debug, Clone)]
struct MemoryCacheItem {
range: ChunkRange,
chunk_byte_indices: Vec<u32>,
data: Vec<u8>,
Contributor:
can we use Bytes here?

Collaborator (Author, @seanses, Nov 7, 2025):
What would be the benefit of Bytes over Vec w.r.t. implementing this ChunkCache trait?

}
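
For reference, the appeal of bytes::Bytes is that slices are reference-counted views rather than copies, so get() could hand back a sub-range without to_vec(). A minimal sketch of that variant (illustrative only, not what this PR implements):

use bytes::Bytes;
use cas_types::ChunkRange;

struct BytesCacheItem {
    range: ChunkRange,
    chunk_byte_indices: Vec<u32>,
    data: Bytes, // reference-counted; clones and slices share one buffer
}

impl BytesCacheItem {
    // Return the cached bytes for [start_byte..end_byte) without copying:
    // Bytes::slice only bumps a refcount and adjusts offsets.
    fn sub_range(&self, start_byte: usize, end_byte: usize) -> Bytes {
        self.data.slice(start_byte..end_byte)
    }
}

That said, as far as this diff shows, CacheRange.data is built with to_vec(), so the copy would just move to the trait boundary unless the trait itself switched to Bytes as well.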

#[derive(Debug, Clone, Default)]
struct CacheState {
inner: HashMap<Key, Vec<MemoryCacheItem>>,
num_items: usize,
total_bytes: u64,
}

impl CacheState {
fn find_match(&self, key: &Key, range: &ChunkRange) -> Option<&MemoryCacheItem> {
let items = self.inner.get(key)?;

items
.iter()
.find(|&item| item.range.start <= range.start && range.end <= item.range.end)
}
}

/// MemoryCache is a ChunkCache implementor that stores data in memory
#[derive(Debug, Clone, Default)]
pub struct MemoryCache {
state: Arc<RwLock<CacheState>>,
}

impl MemoryCache {
pub async fn num_items(&self) -> usize {
self.state.read().await.num_items
}

pub async fn total_bytes(&self) -> u64 {
self.state.read().await.total_bytes
}
}

#[async_trait]
impl ChunkCache for MemoryCache {
async fn get(&self, key: &Key, range: &ChunkRange) -> Result<Option<CacheRange>, ChunkCacheError> {
if range.start >= range.end {
return Err(ChunkCacheError::InvalidArguments);
}

let state = self.state.read().await;
let Some(cache_item) = state.find_match(key, range) else {
return Ok(None);
};

// Extract the requested range from the cached item
let start_idx = (range.start - cache_item.range.start) as usize;
let end_idx = (range.end - cache_item.range.start) as usize;

if end_idx >= cache_item.chunk_byte_indices.len() {
return Err(ChunkCacheError::BadRange);
}

let start_byte = cache_item.chunk_byte_indices[start_idx] as usize;
let end_byte = cache_item.chunk_byte_indices[end_idx] as usize;

if end_byte > cache_item.data.len() {
return Err(ChunkCacheError::BadRange);
}

let data = cache_item.data[start_byte..end_byte].to_vec();
let offsets: Vec<u32> = cache_item.chunk_byte_indices[start_idx..=end_idx]
.iter()
.map(|v| v - cache_item.chunk_byte_indices[start_idx])
.collect();

Ok(Some(CacheRange {
offsets,
data,
range: *range,
}))
}

async fn put(
&self,
key: &Key,
range: &ChunkRange,
chunk_byte_indices: &[u32],
data: &[u8],
) -> Result<(), ChunkCacheError> {
// Validate inputs
if range.start >= range.end
|| chunk_byte_indices.len() != (range.end - range.start + 1) as usize
|| chunk_byte_indices.is_empty()
|| chunk_byte_indices[0] != 0
|| *chunk_byte_indices.last().unwrap() as usize != data.len()
|| !strictly_increasing(chunk_byte_indices)
{
return Err(ChunkCacheError::InvalidArguments);
}

let data_len = data.len() as u64;

let mut state = self.state.write().await;

// Check if we already have this exact range cached
if let Some(items) = state.inner.get(key) {
for item in items.iter() {
if item.range == *range {
// Already cached
return Ok(());
}
}
}

// Add the new item
let cache_item = MemoryCacheItem {
range: *range,
chunk_byte_indices: chunk_byte_indices.to_vec(),
data: data.to_vec(),
};

state.total_bytes += data_len;
state.num_items += 1;

state.inner.entry(key.clone()).or_insert_with(Vec::new).push(cache_item);

Ok(())
}
}

fn strictly_increasing(chunk_byte_indices: &[u32]) -> bool {
for i in 1..chunk_byte_indices.len() {
if chunk_byte_indices[i - 1] >= chunk_byte_indices[i] {
return false;
}
}
true
}

#[cfg(test)]
mod tests {
use cas_types::{ChunkRange, Key};

use super::*;

#[tokio::test]
async fn test_memory_cache_basic() {
let cache = MemoryCache::default();
let key = Key {
prefix: "test".to_string(),
hash: merklehash::MerkleHash::default(),
};
let range = ChunkRange::new(0, 10);
let chunk_byte_indices: Vec<u32> = (0..=10).map(|i| i * 100).collect();
let data = vec![0u8; 1000];

// Put data in cache
cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();

// Get data from cache
let result = cache.get(&key, &range).await.unwrap();
assert!(result.is_some());
let cache_range = result.unwrap();
assert_eq!(cache_range.data, data);
assert_eq!(cache_range.range, range);
}

#[tokio::test]
async fn test_memory_cache_invalid_inputs() {
let cache = MemoryCache::default();
let key = Key {
prefix: "test".to_string(),
hash: merklehash::MerkleHash::default(),
};

// Test invalid range (start >= end)
let invalid_range = ChunkRange::new(10, 5);
let data = vec![0u8; 100];
let chunk_byte_indices = vec![0, 50, 100];
assert!(cache.put(&key, &invalid_range, &chunk_byte_indices, &data).await.is_err());

// Test non-increasing chunk indices
let range = ChunkRange::new(0, 2);
let invalid_indices = vec![0, 50, 20];
assert!(cache.put(&key, &range, &invalid_indices, &data).await.is_err());

// Test empty indices
let empty_indices: Vec<u32> = vec![];
assert!(cache.put(&key, &range, &empty_indices, &data).await.is_err());

// Test non-zero first index
let invalid_start = vec![10, 50, 100];
assert!(cache.put(&key, &range, &invalid_start, &data).await.is_err());

// Test last index != data length
let invalid_end = vec![0, 50, 80];
assert!(cache.put(&key, &range, &invalid_end, &data).await.is_err());
}

#[tokio::test]
async fn test_memory_cache_duplicate_put() {
let cache = MemoryCache::default();
let key = Key {
prefix: "test".to_string(),
hash: merklehash::MerkleHash::default(),
};
let range = ChunkRange::new(0, 2);
let chunk_byte_indices = vec![0, 50, 100];
let data = vec![0u8; 100];

// First put should succeed
cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();
let initial_items = cache.num_items().await;
let initial_bytes = cache.total_bytes().await;

// Second put of same range should not increase counters
cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();
assert_eq!(cache.num_items().await, initial_items);
assert_eq!(cache.total_bytes().await, initial_bytes);
}

#[tokio::test]
async fn test_memory_cache_partial_range() {
let cache = MemoryCache::default();
let key = Key {
prefix: "test".to_string(),
hash: merklehash::MerkleHash::default(),
};
let range = ChunkRange::new(0, 2);
let chunk_byte_indices = vec![0, 50, 100];
let data = vec![1u8; 100];

cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();

// Get partial range
let partial_range = ChunkRange::new(1, 2);
let result = cache.get(&key, &partial_range).await.unwrap().unwrap();
assert_eq!(result.range, partial_range);
assert_eq!(result.data.len(), 50);
assert_eq!(result.offsets, vec![0, 50]);
}
}