diff --git a/cas_client/src/download_utils.rs b/cas_client/src/download_utils.rs index 9fddfc82..20fbd872 100644 --- a/cas_client/src/download_utils.rs +++ b/cas_client/src/download_utils.rs @@ -172,7 +172,8 @@ pub(crate) struct FetchTermDownload { pub range_download_single_flight: RangeDownloadSingleFlight, } -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub(crate) struct SequentialTermDownload { pub term: CASReconstructionTerm, pub download: FetchTermDownload, @@ -180,10 +181,31 @@ pub(crate) struct SequentialTermDownload { pub take: u64, /* number of bytes to take after skipping bytes, * effectively taking [skip_bytes..skip_bytes+take] * out of the downloaded range */ + #[derivative(Debug = "ignore")] + pub coalesced_range_reuse_cache: Arc<MemoryCache>, } impl SequentialTermDownload { pub async fn run(self) -> Result<TermDownloadResult<Vec<u8>>> { + // First try from the coalesced_range_reuse_cache + let key = Key { + prefix: PREFIX_DEFAULT.into(), + hash: self.download.hash, + }; + + if let Ok(Some(cache_item)) = self.coalesced_range_reuse_cache.get(&key, &self.term.range).await { + // extract just the actual range data out of the term download output + let start = self.skip_bytes as usize; + let end = start + self.take as usize; + let final_term_data = &cache_item.data[start..end]; + + return Ok(TermDownloadResult { + payload: final_term_data.to_vec(), + duration: Duration::from_secs(0), + n_retries_on_403: 0, + }); + } + let TermDownloadResult { payload: TermDownloadOutput { @@ -196,8 +218,15 @@ impl SequentialTermDownload { } = self.download.run().await?; // if the requested range is smaller than the fetched range, trim it down to the right data - // the requested range cannot be larger than the fetched range. + // the requested range cannot be larger than the fetched range, and cache the fetched range + // because it will be reused within the segment. // "else" case data matches exact, save some work, return whole data. 
+ if self.term.range.start != chunk_range.start || self.term.range.end != chunk_range.end { + self.coalesced_range_reuse_cache + .put(&key, &chunk_range, &chunk_byte_indices, &data) + .await?; + } + let start_idx = (self.term.range.start - chunk_range.start) as usize; let end_idx = (self.term.range.end - chunk_range.start) as usize; @@ -574,6 +603,7 @@ async fn download_fetch_term_data( mod tests { use anyhow::Result; use cas_types::{HttpRange, QueryReconstructionResponse}; + use chunk_cache::MemoryCache; use http::header::RANGE; use httpmock::prelude::*; use tokio::task::JoinSet; @@ -768,6 +798,7 @@ mod tests { term: terms[0].clone(), skip_bytes: offset_info_first_range, take: file_range.length(), + coalesced_range_reuse_cache: Arc::new(MemoryCache::default()), }; let handle = tokio::spawn(async move { download_task.run().await }); diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 89f2e754..44b72fed 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -413,6 +413,8 @@ impl RemoteClient { }, DownloadQueueItem::Metadata(fetch_info) => { // query for the file info of the first segment + + use chunk_cache::MemoryCache; let segment_size = download_scheduler_clone.next_segment_size()?; debug!(call_id, segment_size, "querying file info"); let (segment, maybe_remainder) = fetch_info.take_segment(segment_size); @@ -427,6 +429,10 @@ impl RemoteClient { // define the term download tasks let mut remaining_segment_len = segment_size; debug!(call_id, num_tasks = terms.len(), "enqueueing download tasks"); + + // in-memory cache for this segment + let coalesced_range_reuse_cache = Arc::new(MemoryCache::default()); + for (i, term) in terms.into_iter().enumerate() { let skip_bytes = if i == 0 { offset_into_first_range } else { 0 }; let take = remaining_total_len @@ -446,6 +452,7 @@ impl RemoteClient { term, skip_bytes, take, + coalesced_range_reuse_cache: coalesced_range_reuse_cache.clone(), }; remaining_total_len 
-= take; diff --git a/chunk_cache/src/lib.rs b/chunk_cache/src/lib.rs index d656cfb0..3389831a 100644 --- a/chunk_cache/src/lib.rs +++ b/chunk_cache/src/lib.rs @@ -1,6 +1,7 @@ mod cache_manager; mod disk; pub mod error; +mod memory; use std::path::PathBuf; @@ -10,6 +11,7 @@ use cas_types::{ChunkRange, Key}; pub use disk::DiskCache; pub use disk::test_utils::*; use error::ChunkCacheError; +pub use memory::MemoryCache; use mockall::automock; pub use crate::disk::DEFAULT_CHUNK_CACHE_CAPACITY; diff --git a/chunk_cache/src/memory.rs b/chunk_cache/src/memory.rs new file mode 100644 index 00000000..5ca17553 --- /dev/null +++ b/chunk_cache/src/memory.rs @@ -0,0 +1,251 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use cas_types::{ChunkRange, Key}; +use tokio::sync::RwLock; + +use crate::error::ChunkCacheError; +use crate::{CacheRange, ChunkCache}; + +#[derive(Debug, Clone)] +struct MemoryCacheItem { + range: ChunkRange, + chunk_byte_indices: Vec<u32>, + data: Vec<u8>, +} + +#[derive(Debug, Clone, Default)] +struct CacheState { + inner: HashMap<Key, Vec<MemoryCacheItem>>, + num_items: usize, + total_bytes: u64, +} + +impl CacheState { + fn find_match(&self, key: &Key, range: &ChunkRange) -> Option<&MemoryCacheItem> { + let items = self.inner.get(key)?; + + items + .iter() + .find(|&item| item.range.start <= range.start && range.end <= item.range.end) + .map(|v| v as _) + } +} + +/// MemoryCache is a ChunkCache implementor that stores data in memory +#[derive(Debug, Clone, Default)] +pub struct MemoryCache { + state: Arc<RwLock<CacheState>>, +} + +impl MemoryCache { + pub async fn num_items(&self) -> usize { + self.state.read().await.num_items + } + + pub async fn total_bytes(&self) -> u64 { + self.state.read().await.total_bytes + } +} + +#[async_trait] +impl ChunkCache for MemoryCache { + async fn get(&self, key: &Key, range: &ChunkRange) -> Result<Option<CacheRange>, ChunkCacheError> { + if range.start >= range.end { + return Err(ChunkCacheError::InvalidArguments); + } + + let state = 
self.state.read().await; + let Some(cache_item) = state.find_match(key, range) else { + return Ok(None); + }; + + // Extract the requested range from the cached item + let start_idx = (range.start - cache_item.range.start) as usize; + let end_idx = (range.end - cache_item.range.start) as usize; + + if end_idx >= cache_item.chunk_byte_indices.len() { + return Err(ChunkCacheError::BadRange); + } + + let start_byte = cache_item.chunk_byte_indices[start_idx] as usize; + let end_byte = cache_item.chunk_byte_indices[end_idx] as usize; + + if end_byte > cache_item.data.len() { + return Err(ChunkCacheError::BadRange); + } + + let data = cache_item.data[start_byte..end_byte].to_vec(); + let offsets: Vec<u32> = cache_item.chunk_byte_indices[start_idx..=end_idx] + .iter() + .map(|v| v - cache_item.chunk_byte_indices[start_idx]) + .collect(); + + Ok(Some(CacheRange { + offsets, + data, + range: *range, + })) + } + + async fn put( + &self, + key: &Key, + range: &ChunkRange, + chunk_byte_indices: &[u32], + data: &[u8], + ) -> Result<(), ChunkCacheError> { + // Validate inputs + if range.start >= range.end + || chunk_byte_indices.len() != (range.end - range.start + 1) as usize + || chunk_byte_indices.is_empty() + || chunk_byte_indices[0] != 0 + || *chunk_byte_indices.last().unwrap() as usize != data.len() + || !strictly_increasing(chunk_byte_indices) + { + return Err(ChunkCacheError::InvalidArguments); + } + + let data_len = data.len() as u64; + + let mut state = self.state.write().await; + + // Check if we already have this exact range cached + if let Some(items) = state.inner.get(key) { + for item in items.iter() { + if item.range == *range { + // Already cached + return Ok(()); + } + } + } + + // Add the new item + let cache_item = MemoryCacheItem { + range: *range, + chunk_byte_indices: chunk_byte_indices.to_vec(), + data: data.to_vec(), + }; + + state.total_bytes += data_len; + state.num_items += 1; + + state.inner.entry(key.clone()).or_insert_with(Vec::new).push(cache_item); + 
Ok(()) + } +} + +fn strictly_increasing(chunk_byte_indices: &[u32]) -> bool { + for i in 1..chunk_byte_indices.len() { + if chunk_byte_indices[i - 1] >= chunk_byte_indices[i] { + return false; + } + } + true +} + +#[cfg(test)] +mod tests { + use cas_types::{ChunkRange, Key}; + + use super::*; + + #[tokio::test] + async fn test_memory_cache_basic() { + let cache = MemoryCache::default(); + let key = Key { + prefix: "test".to_string(), + hash: merklehash::MerkleHash::default(), + }; + let range = ChunkRange::new(0, 10); + let chunk_byte_indices: Vec<u32> = (0..=10).map(|i| i * 100).collect(); + let data = vec![0u8; 1000]; + + // Put data in cache + cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap(); + + // Get data from cache + let result = cache.get(&key, &range).await.unwrap(); + assert!(result.is_some()); + let cache_range = result.unwrap(); + assert_eq!(cache_range.data, data); + assert_eq!(cache_range.range, range); + } + + #[tokio::test] + async fn test_memory_cache_invalid_inputs() { + let cache = MemoryCache::default(); + let key = Key { + prefix: "test".to_string(), + hash: merklehash::MerkleHash::default(), + }; + + // Test invalid range (start >= end) + let invalid_range = ChunkRange::new(10, 5); + let data = vec![0u8; 100]; + let chunk_byte_indices = vec![0, 50, 100]; + assert!(cache.put(&key, &invalid_range, &chunk_byte_indices, &data).await.is_err()); + + // Test non-increasing chunk indices + let range = ChunkRange::new(0, 2); + let invalid_indices = vec![0, 50, 20]; + assert!(cache.put(&key, &range, &invalid_indices, &data).await.is_err()); + + // Test empty indices + let empty_indices: Vec<u32> = vec![]; + assert!(cache.put(&key, &range, &empty_indices, &data).await.is_err()); + + // Test non-zero first index + let invalid_start = vec![10, 50, 100]; + assert!(cache.put(&key, &range, &invalid_start, &data).await.is_err()); + + // Test last index != data length + let invalid_end = vec![0, 50, 80]; + assert!(cache.put(&key, &range, &invalid_end, 
&data).await.is_err()); + } + + #[tokio::test] + async fn test_memory_cache_duplicate_put() { + let cache = MemoryCache::default(); + let key = Key { + prefix: "test".to_string(), + hash: merklehash::MerkleHash::default(), + }; + let range = ChunkRange::new(0, 2); + let chunk_byte_indices = vec![0, 50, 100]; + let data = vec![0u8; 100]; + + // First put should succeed + cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap(); + let initial_items = cache.num_items().await; + let initial_bytes = cache.total_bytes().await; + + // Second put of same range should not increase counters + cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap(); + assert_eq!(cache.num_items().await, initial_items); + assert_eq!(cache.total_bytes().await, initial_bytes); + } + + #[tokio::test] + async fn test_memory_cache_partial_range() { + let cache = MemoryCache::default(); + let key = Key { + prefix: "test".to_string(), + hash: merklehash::MerkleHash::default(), + }; + let range = ChunkRange::new(0, 2); + let chunk_byte_indices = vec![0, 50, 100]; + let data = vec![1u8; 100]; + + cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap(); + + // Get partial range + let partial_range = ChunkRange::new(1, 2); + let result = cache.get(&key, &partial_range).await.unwrap().unwrap(); + assert_eq!(result.range, partial_range); + assert_eq!(result.data.len(), 50); + assert_eq!(result.offsets, vec![0, 50]); + } +}