Sequential download uses an in-memory cache for coalesced range within a segment #558
```diff
@@ -172,18 +172,40 @@ pub(crate) struct FetchTermDownload {
     pub range_download_single_flight: RangeDownloadSingleFlight,
 }

-#[derive(Debug)]
+#[derive(Derivative)]
+#[derivative(Debug)]
 pub(crate) struct SequentialTermDownload {
     pub term: CASReconstructionTerm,
     pub download: FetchTermDownload,
     pub skip_bytes: u64, // number of bytes to skip at the front
     pub take: u64,       /* number of bytes to take after skipping bytes,
                           * effectively taking [skip_bytes..skip_bytes+take]
                           * out of the downloaded range */
+    #[derivative(Debug = "ignore")]
+    pub coalesced_range_reuse_cache: Arc<dyn ChunkCache>,
 }

 impl SequentialTermDownload {
     pub async fn run(self) -> Result<TermDownloadResult<Vec<u8>>> {
+        // First try from the coalesced_range_reuse_cache
+        let key = Key {
+            prefix: PREFIX_DEFAULT.into(),
+            hash: self.download.hash,
+        };
+
+        if let Ok(Some(cache_item)) = self.coalesced_range_reuse_cache.get(&key, &self.term.range).await {
+            // extract just the actual range data out of the term download output
+            let start = self.skip_bytes as usize;
+            let end = start + self.take as usize;
+            let final_term_data = &cache_item.data[start..end];
+
+            return Ok(TermDownloadResult {
+                payload: final_term_data.to_vec(),
+                duration: Duration::from_secs(0),
+                n_retries_on_403: 0,
+            });
+        }
+
         let TermDownloadResult {
             payload:
                 TermDownloadOutput {

@@ -196,8 +218,15 @@ impl SequentialTermDownload {
         } = self.download.run().await?;

         // if the requested range is smaller than the fetched range, trim it down to the right data
-        // the requested range cannot be larger than the fetched range.
+        // the requested range cannot be larger than the fetched range, and cache the fetched range
+        // because it will be reused within the segment.
+        // "else" case data matches exact, save some work, return whole data.
+        if self.term.range.start != chunk_range.start || self.term.range.end != chunk_range.end {
```

> **Review comment:** This condition captures some of the cases but not all. It helps with the case we're dealing with, but it's not general for everything.

> **Reply:** True, but I think it's a pretty solid heuristic here -- the only case it doesn't capture is when the downloaded range is referenced in full in this situation and then partially referenced later. But I think for now, best-effort is probably fine.
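To make that uncaptured case concrete, here is a small hedged sketch (the chunk indices are invented; `ChunkRange` is the `cas_types` type used throughout this diff):

```rust
use cas_types::ChunkRange;

fn main() {
    // Hypothetical scenario: the coalesced fetch covers exactly what term A asks for.
    let fetched = ChunkRange::new(0, 8); // range actually downloaded
    let term_a = ChunkRange::new(0, 8);  // referenced in full: heuristic condition is false
    let term_b = ChunkRange::new(2, 5);  // later partial reference to the same xorb

    // The heuristic only caches when the term range differs from the fetched range,
    // so term A's download is never put() into the cache...
    let cached = term_a.start != fetched.start || term_a.end != fetched.end;
    assert!(!cached);

    // ...even though term B's range is fully contained in it and would have hit.
    assert!(fetched.start <= term_b.start && term_b.end <= fetched.end);
}
```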
```diff
+            self.coalesced_range_reuse_cache
+                .put(&key, &chunk_range, &chunk_byte_indices, &data)
+                .await?;
+        }

         let start_idx = (self.term.range.start - chunk_range.start) as usize;
         let end_idx = (self.term.range.end - chunk_range.start) as usize;
```
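For reference, a worked example of the trimming arithmetic just above (all values invented):

```rust
fn main() {
    // The fetched coalesced range covers chunks [4, 12); the term wants [6, 9).
    let chunk_range_start = 4u32; // chunk_range.start in the code above
    let term_start = 6u32;        // self.term.range.start
    let term_end = 9u32;          // self.term.range.end

    // Chunk offsets relative to the start of the fetched range:
    let start_idx = (term_start - chunk_range_start) as usize; // 2
    let end_idx = (term_end - chunk_range_start) as usize;     // 5
    assert_eq!((start_idx, end_idx), (2, 5));
}
```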
```diff
@@ -574,6 +603,7 @@ async fn download_fetch_term_data(
 mod tests {
     use anyhow::Result;
     use cas_types::{HttpRange, QueryReconstructionResponse};
+    use chunk_cache::MemoryCache;
     use http::header::RANGE;
     use httpmock::prelude::*;
     use tokio::task::JoinSet;

@@ -768,6 +798,7 @@ mod tests {
             term: terms[0].clone(),
             skip_bytes: offset_info_first_range,
             take: file_range.length(),
+            coalesced_range_reuse_cache: Arc::new(MemoryCache::default()),
         };

         let handle = tokio::spawn(async move { download_task.run().await });
```
```diff
@@ -413,6 +413,8 @@ impl RemoteClient {
             },
             DownloadQueueItem::Metadata(fetch_info) => {
                 // query for the file info of the first segment
+
+                use chunk_cache::MemoryCache;
                 let segment_size = download_scheduler_clone.next_segment_size()?;
                 debug!(call_id, segment_size, "querying file info");
                 let (segment, maybe_remainder) = fetch_info.take_segment(segment_size);

@@ -427,6 +429,10 @@ impl RemoteClient {
                 // define the term download tasks
                 let mut remaining_segment_len = segment_size;
                 debug!(call_id, num_tasks = terms.len(), "enqueueing download tasks");
+
+                // in-memory cache for this segment
+                let coalesced_range_reuse_cache = Arc::new(MemoryCache::default());
```
> **Review comment:** Like @hoytak I'm worried this could suck up too much RAM since it can keep growing. I think capping the overall byte size of the cache would be good - maybe the default is 1GB but configurable by env variable. Instead of random eviction, do we want to evict based on a min heap of `num_bytes_in_segment`? Also, I think we should probably log a bit more so we know how many times a segment is reused. That is likely nice-to-have.

> **Reply:** I don't think evicting based on `num_bytes_in_segment` would work. Then the cache will fill up with large chunks, and any small chunks in it will get immediately evicted before they would be used. A better heuristic here is perhaps insertion time. However, I think the memory issues can possibly be resolved by having a common cache with the larger limit instead of the per-download cache. Or simply by changing that default to be smaller, reflecting the limit on the number of parallel downloads.

> **Reply:** Isn't the main optimization goal of adding a memory cache to limit the number of network requests needed to reconstruct a file? By minimizing this we reduce overall download time and we ensure we don't download more bytes than we need (repeatedly downloading the same segment). So with an upper bound on cache size we need to evict entries. The entries we evict should be the least-used and smallest segments, since evicting those results in the fewest extra network requests. Since we get all the reconstruction info at the beginning of the download, we can do the bookkeeping to see which segments have the most terms, and maybe weight those segments by overall byte count when inserting into the cache. Then as we consume from the cache we can decrement the remaining terms for a segment and remove it from the cache once the segment is fully consumed. When evicting, we evict the segments with the least usage and fewest bytes, preserving the biggest segments with the most remaining terms. Yes, this will result in the cache filling up with large chunks, but those large chunks also have lots of remaining terms to be consumed, so they are worth keeping - and they will be removed from the cache once all their terms are consumed anyway.

> **Reply:** The map would be something like `weighted_score = (segment_bytes / configured_segment_size) / num_terms_remaining`: segment bytes as a percentage of the configurable segment size, divided by the number of terms the segment satisfies. The smallest of these weighted scores should be evicted. This way, on a cache hit, once `num_terms_remaining` hits 0 the entire cache entry is removed.
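A minimal sketch of the weighted-score idea from that last comment, with invented names (`SegmentEntry`, `weighted_score`); this is one possible reading of the proposal, not code from the PR:

```rust
/// Hypothetical eviction bookkeeping, not part of the PR.
struct SegmentEntry {
    segment_bytes: u64,
    num_terms_remaining: u32,
}

impl SegmentEntry {
    /// Smaller score => evicted first: small segments that satisfy few
    /// remaining terms are the cheapest ones to re-download.
    fn weighted_score(&self, configured_segment_size: u64) -> f64 {
        let size_frac = self.segment_bytes as f64 / configured_segment_size as f64;
        size_frac / self.num_terms_remaining.max(1) as f64
    }
}

fn main() {
    let configured_segment_size = 64u64 << 20; // 64 MiB, invented
    let big = SegmentEntry { segment_bytes: 64 << 20, num_terms_remaining: 10 };
    let small = SegmentEntry { segment_bytes: 1 << 20, num_terms_remaining: 1 };

    // The small, rarely-reused segment scores lower and would be evicted first;
    // an entry whose num_terms_remaining reaches 0 is removed outright.
    assert!(small.weighted_score(configured_segment_size) < big.weighted_score(configured_segment_size));
}
```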
```diff
                 for (i, term) in terms.into_iter().enumerate() {
                     let skip_bytes = if i == 0 { offset_into_first_range } else { 0 };
                     let take = remaining_total_len

@@ -446,6 +452,7 @@ impl RemoteClient {
                         term,
                         skip_bytes,
                         take,
+                        coalesced_range_reuse_cache: coalesced_range_reuse_cache.clone(),
                     };

                     remaining_total_len -= take;
```

> **Review comment:** No evictions - this could get large. I think it's fair to assume that for fragmented files we will have all of the ranges in here.
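If a strict no-eviction policy is kept, one hypothetical way to still bound memory (not in the PR; all names invented) is a byte budget that turns further inserts into no-ops once exhausted:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Hypothetical byte-budget guard; a caller would check it before put().
struct ByteBudget {
    used: AtomicU64,
    max_bytes: u64,
}

impl ByteBudget {
    fn new(max_bytes: u64) -> Self {
        Self { used: AtomicU64::new(0), max_bytes }
    }

    /// Returns true if `len` bytes were reserved; false means "skip caching".
    fn try_reserve(&self, len: u64) -> bool {
        self.used
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |used| {
                (used + len <= self.max_bytes).then_some(used + len)
            })
            .is_ok()
    }
}

fn main() {
    let budget = ByteBudget::new(100);
    assert!(budget.try_reserve(60));
    assert!(!budget.try_reserve(50)); // over budget: the caller skips the put()
}
```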
New file (`@@ -0,0 +1,251 @@`): the in-memory `ChunkCache` implementation added to the `chunk_cache` crate.

```rust
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use cas_types::{ChunkRange, Key};
use tokio::sync::RwLock;

use crate::error::ChunkCacheError;
use crate::{CacheRange, ChunkCache};

#[derive(Debug, Clone)]
struct MemoryCacheItem {
    range: ChunkRange,
    chunk_byte_indices: Vec<u32>,
    data: Vec<u8>,
```

> **Review comment:** Can we use `Bytes` here?

> **Reply:** What would be the benefit of `Bytes` over `Vec` w.r.t. implementing this `ChunkCache` trait?
```rust
}

#[derive(Debug, Clone, Default)]
struct CacheState {
    inner: HashMap<Key, Vec<MemoryCacheItem>>,
    num_items: usize,
    total_bytes: u64,
}

impl CacheState {
    fn find_match(&self, key: &Key, range: &ChunkRange) -> Option<&MemoryCacheItem> {
        let items = self.inner.get(key)?;

        items
            .iter()
            .find(|&item| item.range.start <= range.start && range.end <= item.range.end)
            .map(|v| v as _)
    }
}

/// MemoryCache is a ChunkCache implementor that stores data in memory
#[derive(Debug, Clone, Default)]
pub struct MemoryCache {
    state: Arc<RwLock<CacheState>>,
}

impl MemoryCache {
    pub async fn num_items(&self) -> usize {
        self.state.read().await.num_items
    }

    pub async fn total_bytes(&self) -> u64 {
        self.state.read().await.total_bytes
    }
}

#[async_trait]
impl ChunkCache for MemoryCache {
    async fn get(&self, key: &Key, range: &ChunkRange) -> Result<Option<CacheRange>, ChunkCacheError> {
        if range.start >= range.end {
            return Err(ChunkCacheError::InvalidArguments);
        }

        let state = self.state.read().await;
        let Some(cache_item) = state.find_match(key, range) else {
            return Ok(None);
        };

        // Extract the requested range from the cached item
        let start_idx = (range.start - cache_item.range.start) as usize;
        let end_idx = (range.end - cache_item.range.start) as usize;

        if end_idx >= cache_item.chunk_byte_indices.len() {
            return Err(ChunkCacheError::BadRange);
        }

        let start_byte = cache_item.chunk_byte_indices[start_idx] as usize;
        let end_byte = cache_item.chunk_byte_indices[end_idx] as usize;

        if end_byte > cache_item.data.len() {
            return Err(ChunkCacheError::BadRange);
        }

        let data = cache_item.data[start_byte..end_byte].to_vec();
        let offsets: Vec<u32> = cache_item.chunk_byte_indices[start_idx..=end_idx]
            .iter()
            .map(|v| v - cache_item.chunk_byte_indices[start_idx])
            .collect();

        Ok(Some(CacheRange {
            offsets,
            data,
            range: *range,
        }))
    }

    async fn put(
        &self,
        key: &Key,
        range: &ChunkRange,
        chunk_byte_indices: &[u32],
        data: &[u8],
    ) -> Result<(), ChunkCacheError> {
        // Validate inputs
        if range.start >= range.end
            || chunk_byte_indices.len() != (range.end - range.start + 1) as usize
            || chunk_byte_indices.is_empty()
            || chunk_byte_indices[0] != 0
            || *chunk_byte_indices.last().unwrap() as usize != data.len()
            || !strictly_increasing(chunk_byte_indices)
        {
            return Err(ChunkCacheError::InvalidArguments);
        }

        let data_len = data.len() as u64;

        let mut state = self.state.write().await;

        // Check if we already have this exact range cached
        if let Some(items) = state.inner.get(key) {
            for item in items.iter() {
                if item.range == *range {
                    // Already cached
                    return Ok(());
                }
            }
        }

        // Add the new item
        let cache_item = MemoryCacheItem {
            range: *range,
            chunk_byte_indices: chunk_byte_indices.to_vec(),
            data: data.to_vec(),
        };

        state.total_bytes += data_len;
        state.num_items += 1;

        state.inner.entry(key.clone()).or_insert_with(Vec::new).push(cache_item);

        Ok(())
    }
}

fn strictly_increasing(chunk_byte_indices: &[u32]) -> bool {
    for i in 1..chunk_byte_indices.len() {
        if chunk_byte_indices[i - 1] >= chunk_byte_indices[i] {
            return false;
        }
    }
    true
}

#[cfg(test)]
mod tests {
    use cas_types::{ChunkRange, Key};

    use super::*;

    #[tokio::test]
    async fn test_memory_cache_basic() {
        let cache = MemoryCache::default();
        let key = Key {
            prefix: "test".to_string(),
            hash: merklehash::MerkleHash::default(),
        };
        let range = ChunkRange::new(0, 10);
        let chunk_byte_indices: Vec<u32> = (0..=10).map(|i| i * 100).collect();
        let data = vec![0u8; 1000];

        // Put data in cache
        cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();

        // Get data from cache
        let result = cache.get(&key, &range).await.unwrap();
        assert!(result.is_some());
        let cache_range = result.unwrap();
        assert_eq!(cache_range.data, data);
        assert_eq!(cache_range.range, range);
    }

    #[tokio::test]
    async fn test_memory_cache_invalid_inputs() {
        let cache = MemoryCache::default();
        let key = Key {
            prefix: "test".to_string(),
            hash: merklehash::MerkleHash::default(),
        };

        // Test invalid range (start >= end)
        let invalid_range = ChunkRange::new(10, 5);
        let data = vec![0u8; 100];
        let chunk_byte_indices = vec![0, 50, 100];
        assert!(cache.put(&key, &invalid_range, &chunk_byte_indices, &data).await.is_err());

        // Test non-increasing chunk indices
        let range = ChunkRange::new(0, 2);
        let invalid_indices = vec![0, 50, 20];
        assert!(cache.put(&key, &range, &invalid_indices, &data).await.is_err());

        // Test empty indices
        let empty_indices: Vec<u32> = vec![];
        assert!(cache.put(&key, &range, &empty_indices, &data).await.is_err());

        // Test non-zero first index
        let invalid_start = vec![10, 50, 100];
        assert!(cache.put(&key, &range, &invalid_start, &data).await.is_err());

        // Test last index != data length
        let invalid_end = vec![0, 50, 80];
        assert!(cache.put(&key, &range, &invalid_end, &data).await.is_err());
    }

    #[tokio::test]
    async fn test_memory_cache_duplicate_put() {
        let cache = MemoryCache::default();
        let key = Key {
            prefix: "test".to_string(),
            hash: merklehash::MerkleHash::default(),
        };
        let range = ChunkRange::new(0, 2);
        let chunk_byte_indices = vec![0, 50, 100];
        let data = vec![0u8; 100];

        // First put should succeed
        cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();
        let initial_items = cache.num_items().await;
        let initial_bytes = cache.total_bytes().await;

        // Second put of same range should not increase counters
        cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();
        assert_eq!(cache.num_items().await, initial_items);
        assert_eq!(cache.total_bytes().await, initial_bytes);
    }

    #[tokio::test]
    async fn test_memory_cache_partial_range() {
        let cache = MemoryCache::default();
        let key = Key {
            prefix: "test".to_string(),
            hash: merklehash::MerkleHash::default(),
        };
        let range = ChunkRange::new(0, 2);
        let chunk_byte_indices = vec![0, 50, 100];
        let data = vec![1u8; 100];

        cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();

        // Get partial range
        let partial_range = ChunkRange::new(1, 2);
        let result = cache.get(&key, &partial_range).await.unwrap().unwrap();
        assert_eq!(result.range, partial_range);
        assert_eq!(result.data.len(), 50);
        assert_eq!(result.offsets, vec![0, 50]);
    }
}
```
> **Review comment:** I wish there was a way to use one cache passed in here - there's already one in `FetchTermDownload`.

> **Reply:** This is definitely a trade-off. Using a disk cache brings down memory usage, but as Adrien and Corentin mentioned, a disk cache slows down their environment a LOT. I don't think there is a universal solution for all environments, and this approach is tailored for their usage - meaning I don't think an in-memory cache is suitable for all hf-xet users either.
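A minimal sketch of that single-shared-cache idea, assuming both download paths accepted the same trait object (`ChunkCache` and `MemoryCache` are from this PR; the wiring itself is hypothetical):

```rust
use std::sync::Arc;

use chunk_cache::{ChunkCache, MemoryCache};

fn main() {
    // One cache handle, cloned into every consumer, instead of a fresh
    // MemoryCache per segment plus a separate cache inside FetchTermDownload.
    let shared: Arc<dyn ChunkCache> = Arc::new(MemoryCache::default());

    let _for_sequential_term_download = shared.clone();
    let _for_fetch_term_download = shared.clone();

    // All three handles point at the same in-memory state.
    assert_eq!(Arc::strong_count(&shared), 3);
}
```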