From 4346b69fc528a3834a68ae66411d35c7efb243ba Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Tue, 12 Aug 2025 19:01:51 -0400 Subject: [PATCH 1/5] Create TTL map with time wheel architecture This change adds a DashMap-like struct which has a background tasks to clean up entries that have outlived a configurable TTL. This struct is simliar to https://github.com/moka-rs/moka, which also uses time wheels. Having our own module avoids introducing a large dependency, which keeps this project closer to vanilla datafusion. This change is meant to be useful for https://github.com/datafusion-contrib/datafusion-distributed/pull/89, where it's possible for `ExecutionStages` to be orphaned in `ArrowFlightEndpoint`. We need an async task to clean up old entries. Informs: https://github.com/datafusion-contrib/datafusion-distributed/issues/90 --- src/common/mod.rs | 1 + src/common/ttl_map.rs | 343 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 344 insertions(+) create mode 100644 src/common/ttl_map.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 812d1ed..572b996 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1 +1,2 @@ +pub mod ttl_map; pub mod util; diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs new file mode 100644 index 0000000..6930177 --- /dev/null +++ b/src/common/ttl_map.rs @@ -0,0 +1,343 @@ +/* +TTLMap is a DashMap that automatically removes entries after a specified time-to-live (TTL). + +How the Time Wheel Works + +Time Buckets: [0] [1] [2] [3] [4] [5] [6] [7] ... +Current Time: ^ + | + time % buckets.len() + +When inserting key "A" at time=2: +- Key "A" goes into bucket[(2-1) % 8] = bucket[1] +- Key "A" will be expired when time advances to bucket[1] again + +Generally, keys in a bucket expire when the wheel makes a full rotation, making +the total TTL equal to the tick duration * buckets.len(). + +Usage +```rust +let params = TTLMapParams { tick: Duration::from_secs(30), ttl: Duration::from_mins(5) }; +let ttl_map = TTLMap::new(params).await.unwrap(); +let value = ttl_map.get_or_init(key, || initial_value).await; +``` + +TODO: If an existing entry is accessed, we don't extend its TTL. It's unclear if this is +necessary for any use cases. This functionality could be added if needed. + */ +use dashmap::{DashMap, Entry}; +use datafusion::error::DataFusionError; +use std::collections::HashSet; +use std::hash::Hash; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; + +// TTLMap is a key-value store that automatically removes entries after a specified time-to-live. +pub struct TTLMap { + /// Time wheel buckets containing keys to be expired. Each bucket epresents + /// a time slot. Keys in bucket[i] will be expired when time % buckets.len() == i + buckets: Arc>>>, + + /// The actual key-value storage using DashMap for concurrent access + data: Arc>, + + /// Atomic counter tracking the current time slot for the time wheel. + /// Incremented by the background GC task every `tick` duration. + time: Arc, + + /// Background task handle for the garbage collection process. + /// When dropped, the GC task is automatically aborted. + _task: Option>, + + // grandularity of the time wheel. How often a bucket is cleared. + tick: Duration, +} + +pub struct TTLMapParams { + // tick is how often the map is checks for expired entries + // must be less than ttl + pub tick: Duration, + // ttl is the time-to-live for entries + pub ttl: Duration, +} + +impl Default for TTLMapParams { + fn default() -> Self { + Self { + tick: Duration::from_secs(3), + ttl: Duration::from_secs(60), + } + } +} + +impl TTLMap +where + K: Eq + Hash + Send + Sync + Clone + 'static, + V: Default + Clone + Send + Sync + 'static, +{ + // new creates a new TTLMap. + pub async fn new(params: TTLMapParams) -> Result { + if params.tick > params.ttl { + return Err(DataFusionError::Configuration( + "tick duration must be less than or equal to ttl duration".to_string(), + )); + } + let mut map = Self::_new(params.tick, params.ttl).await; + map._start_gc(); + Ok(map) + } + + async fn _new(tick: Duration, ttl: Duration) -> Self { + let bucket_count = (ttl.as_millis() / tick.as_millis()) as usize; + let mut buckets = Vec::with_capacity(bucket_count); + for _ in 0..bucket_count { + buckets.push(HashSet::new()); + } + let stage_targets = Arc::new(DashMap::new()); + let time_wheel = Arc::new(Mutex::new(buckets)); + let time = Arc::new(AtomicU64::new(0)); + Self { + buckets: time_wheel, + data: stage_targets, + time, + _task: None, + tick, + } + } + + // Start and set the background GC task. + fn _start_gc(&mut self) { + self._task = Some(tokio::spawn(Self::run_gc_loop( + self.data.clone(), + self.buckets.clone(), + self.time.clone(), + self.tick, + ))) + } + + /// get_or_default executes the provided closure with a reference to the map entry for the given key. + /// If the key does not exist, it inserts a new entry with the default value. + pub async fn get_or_init(&self, key: K, f: F) -> V + where + F: FnOnce() -> V, + { + let mut new_entry = false; + let value = match self.data.entry(key.clone()) { + Entry::Vacant(entry) => { + let value = f(); + entry.insert(value.clone()); + new_entry = true; + value + } + Entry::Occupied(entry) => entry.get().clone(), + }; + + // Insert the key into the previous bucket, meaning the key will be evicted + // when the wheel completes a full rotation. + if new_entry { + let time = self.time.load(std::sync::atomic::Ordering::SeqCst); + { + let mut buckets = self.buckets.lock().await; + let bucket_index = (time.wrapping_sub(1)) % buckets.len() as u64; + buckets[bucket_index as usize].insert(key); + } + } + + value + } + + /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The + /// function terminates if `shutdown` is signalled. + async fn run_gc_loop( + map: Arc>, + time_wheel: Arc>>>, + time: Arc, + period: Duration, + ) { + loop { + Self::gc(map.clone(), time_wheel.clone(), time.clone()).await; + tokio::time::sleep(period).await; + } + } + + /// gc clears expired entries from the map and advances time by 1. + async fn gc( + map: Arc>, + time_wheel: Arc>>>, + time: Arc, + ) { + let keys = { + let mut guard = time_wheel.lock().await; + let len = guard.len(); + let index = time.load(std::sync::atomic::Ordering::SeqCst) % len as u64; + // Replace the HashSet at the index with an empty one and return the original + std::mem::replace(&mut guard[index as usize], HashSet::new()) + }; + + // Remove expired keys from the map. + // TODO: it may be worth exploring if we can group keys by shard and do a batched + // remove. + for key in keys { + map.remove(&key); + } + // May wrap. + time.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::Ordering; + use tokio::time::{sleep, Duration}; + + #[tokio::test] + async fn test_basic_insert_and_get() { + let ttl_map = + TTLMap::::_new(Duration::from_millis(100), Duration::from_secs(1)).await; + + ttl_map.get_or_init("key1".to_string(), || 42).await; + + let value = ttl_map.get_or_init("key1".to_string(), || 0).await; + assert_eq!(value, 42); + } + + #[tokio::test] + async fn test_time_wheel_bucket_calculation() { + let ttl_map = + TTLMap::::_new(Duration::from_millis(100), Duration::from_secs(1)).await; + + // With 1s TTL and 100ms tick, we should have 10 buckets + assert_eq!(ttl_map.buckets.lock().await.len(), 10); + } + + #[tokio::test] + async fn test_gc_expiration() { + let ttl_map = + TTLMap::::_new(Duration::from_millis(100), Duration::from_secs(1)).await; + + // Initial batch of entries + ttl_map.get_or_init("key1".to_string(), || 42).await; + ttl_map.get_or_init("key2".to_string(), || 84).await; + assert_eq!(ttl_map.data.len(), 2); + + // Run partial GC cycles (should not expire yet) + for _ in 0..5 { + TTLMap::gc( + ttl_map.data.clone(), + ttl_map.buckets.clone(), + ttl_map.time.clone(), + ) + .await; + } + assert_eq!(ttl_map.data.len(), 2); // Still there + + // Add more entries mid-cycle + ttl_map.get_or_init("key3".to_string(), || 168).await; + ttl_map.get_or_init("key4".to_string(), || 0).await; // Default value (0) + ttl_map.get_or_init("key5".to_string(), || 210).await; + assert_eq!(ttl_map.data.len(), 5); + + // Verify default value was set + let default_value = ttl_map.get_or_init("key4".to_string(), || 0).await; + assert_eq!(default_value, 0); + + // Complete the first rotation to expire initial entries + for _ in 5..10 { + TTLMap::gc( + ttl_map.data.clone(), + ttl_map.buckets.clone(), + ttl_map.time.clone(), + ) + .await; + } + assert_eq!(ttl_map.data.len(), 3); // Initial entries expired, new entries still alive + + // Add entries after expiration + ttl_map.get_or_init("new_key1".to_string(), || 999).await; + ttl_map.get_or_init("new_key2".to_string(), || 0).await; // Default value + assert_eq!(ttl_map.data.len(), 5); // 3 from mid-cycle + 2 new ones + + // Verify values + let value1 = ttl_map.get_or_init("new_key1".to_string(), || 0).await; + assert_eq!(value1, 999); + let value2 = ttl_map.get_or_init("new_key2".to_string(), || 0).await; + assert_eq!(value2, 0); + + // Run additional GC cycles to expire remaining entries + // Mid-cycle entries (bucket 4) expire at time=14, late entries (bucket 9) expire at time=19 + for _ in 10..20 { + TTLMap::gc( + ttl_map.data.clone(), + ttl_map.buckets.clone(), + ttl_map.time.clone(), + ) + .await; + } + assert_eq!(ttl_map.data.len(), 0); // All entries expired + } + + #[tokio::test] + async fn test_concurrent_gc_and_access() { + let ttl_map = TTLMap::::new(TTLMapParams { + tick: Duration::from_millis(2), + ttl: Duration::from_millis(10), + }) + .await + .unwrap(); + + assert!(ttl_map._task.is_some()); + + let ttl_map = Arc::new(ttl_map); + + // Spawn 5 concurrent tasks + let mut handles = Vec::new(); + for task_id in 0..5 { + let map = Arc::clone(&ttl_map); + let handle = tokio::spawn(async move { + for i in 0..20 { + let key = format!("task{}_key{}", task_id, i % 4); + map.get_or_init(key, || task_id * 100 + i).await; + sleep(Duration::from_millis(1)).await; + } + }); + handles.push(handle); + } + + // Wait for all tasks to complete + for handle in handles { + handle.await.unwrap(); + } + } + + #[tokio::test] + async fn test_wraparound_time() { + let ttl_map = TTLMap::::_new( + Duration::from_millis(10), + Duration::from_millis(20), // 2 buckets + ) + .await; + + // Manually set time near overflow + ttl_map.time.store(u64::MAX - 2, Ordering::SeqCst); + + ttl_map.get_or_init("test_key".to_string(), || 999).await; + + // Run GC to cause time wraparound + for _ in 0..5 { + TTLMap::gc( + ttl_map.data.clone(), + ttl_map.buckets.clone(), + ttl_map.time.clone(), + ) + .await; + } + + // Entry should be expired and time should have wrapped + assert_eq!(ttl_map.data.len(), 0); + let final_time = ttl_map.time.load(Ordering::SeqCst); + assert!(final_time < 100); + } +} From b557c6317d46fae2ee777689548f384beab9ef59 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Wed, 13 Aug 2025 14:37:56 -0400 Subject: [PATCH 2/5] add bench --- src/common/ttl_map.rs | 68 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index 6930177..c301306 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -90,7 +90,7 @@ where } async fn _new(tick: Duration, ttl: Duration) -> Self { - let bucket_count = (ttl.as_millis() / tick.as_millis()) as usize; + let bucket_count = (ttl.as_nanos() / tick.as_nanos()) as usize; let mut buckets = Vec::with_capacity(bucket_count); for _ in 0..bucket_count { buckets.push(HashSet::new()); @@ -340,4 +340,70 @@ mod tests { let final_time = ttl_map.time.load(Ordering::SeqCst); assert!(final_time < 100); } + + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn bench_lock_contention() { + use std::time::Instant; + + let ttl_map = TTLMap::::new(TTLMapParams { + tick: Duration::from_micros(1), + ttl: Duration::from_micros(2), + }) + .await + .unwrap(); + + let ttl_map = Arc::new(ttl_map); + + let key_count = 10; + let start_time = Instant::now(); + let operations_per_task = 1_000_000; + let task_count = 100; + + // Spawn 10 tasks that repeatedly read the same keys + let mut handles = Vec::new(); + for task_id in 0..task_count { + let map = Arc::clone(&ttl_map); + let handle = tokio::spawn(async move { + let mut local_ops = 0; + for i in 0..operations_per_task { + // All tasks fight for the same keys - maximum contention + let key = format!("key{}", i % key_count); + let _value = map.get_or_init(key, || task_id * 1000 + i).await; + local_ops += 1; + + // Small yield to allow GC to run frequently + if i % 10 == 0 { + tokio::task::yield_now().await; + } + } + local_ops + }); + handles.push(handle); + } + + // Wait for all tasks and collect operation counts + let mut total_operations = 0; + for handle in handles { + total_operations += handle.await.unwrap(); + } + + let elapsed = start_time.elapsed(); + let ops_per_second = total_operations as f64 / elapsed.as_secs_f64(); + let avg_latency_us = elapsed.as_micros() as f64 / total_operations as f64; + + println!("\n=== TTLMap Lock Contention Benchmark ==="); + println!("Tasks: {}", task_count); + println!("Operations per task: {}", operations_per_task); + println!("Total operations: {}", total_operations); + println!("Total time: {:.2?}", elapsed); + println!("Throughput: {:.0} ops/sec", ops_per_second); + println!("Average latency: {:.2} μs per operation", avg_latency_us); + println!("Entries remaining: {}", ttl_map.data.len()); + + // The benchmark passes if it completes without deadlocks + // Performance metrics are printed for analysis + assert!(ops_per_second > 0.0); // Sanity check + } + } From e99e38084e83cb07e0ff83f75ea2e5451b533d6c Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Thu, 14 Aug 2025 08:41:37 +0200 Subject: [PATCH 3/5] Benchmark suggestion --- src/common/ttl_map.rs | 68 +++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index c301306..266ac52 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -29,7 +29,8 @@ use dashmap::{DashMap, Entry}; use datafusion::error::DataFusionError; use std::collections::HashSet; use std::hash::Hash; -use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::atomic::{AtomicU64, AtomicUsize}; use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; @@ -53,6 +54,9 @@ pub struct TTLMap { // grandularity of the time wheel. How often a bucket is cleared. tick: Duration, + + dash_map_lock_contention_time: AtomicUsize, + mutex_lock_contention_time: AtomicUsize, } pub struct TTLMapParams { @@ -104,6 +108,8 @@ where time, _task: None, tick, + dash_map_lock_contention_time: AtomicUsize::new(0), + mutex_lock_contention_time: AtomicUsize::new(0), } } @@ -124,7 +130,11 @@ where F: FnOnce() -> V, { let mut new_entry = false; - let value = match self.data.entry(key.clone()) { + let start = std::time::Instant::now(); + let entry = self.data.entry(key.clone()); + self.dash_map_lock_contention_time + .fetch_add(start.elapsed().as_nanos() as usize, Relaxed); + let value = match entry { Entry::Vacant(entry) => { let value = f(); entry.insert(value.clone()); @@ -139,7 +149,10 @@ where if new_entry { let time = self.time.load(std::sync::atomic::Ordering::SeqCst); { + let start = std::time::Instant::now(); let mut buckets = self.buckets.lock().await; + self.mutex_lock_contention_time + .fetch_add(start.elapsed().as_nanos() as usize, Relaxed); let bucket_index = (time.wrapping_sub(1)) % buckets.len() as u64; buckets[bucket_index as usize].insert(key); } @@ -341,12 +354,11 @@ mod tests { assert!(final_time < 100); } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[tokio::test(flavor = "multi_thread", worker_threads = 16)] async fn bench_lock_contention() { use std::time::Instant; - let ttl_map = TTLMap::::new(TTLMapParams { + let ttl_map = TTLMap::::new(TTLMapParams { tick: Duration::from_micros(1), ttl: Duration::from_micros(2), }) @@ -355,55 +367,43 @@ mod tests { let ttl_map = Arc::new(ttl_map); - let key_count = 10; let start_time = Instant::now(); - let operations_per_task = 1_000_000; - let task_count = 100; + let task_count = 100_000; // Spawn 10 tasks that repeatedly read the same keys let mut handles = Vec::new(); for task_id in 0..task_count { let map = Arc::clone(&ttl_map); let handle = tokio::spawn(async move { - let mut local_ops = 0; - for i in 0..operations_per_task { - // All tasks fight for the same keys - maximum contention - let key = format!("key{}", i % key_count); - let _value = map.get_or_init(key, || task_id * 1000 + i).await; - local_ops += 1; - - // Small yield to allow GC to run frequently - if i % 10 == 0 { - tokio::task::yield_now().await; - } - } - local_ops + // All tasks fight for the same keys - maximum contention + let start = Instant::now(); + let _value = map.get_or_init(rand::random(), || task_id * 1000).await; + start.elapsed().as_nanos() }); handles.push(handle); } // Wait for all tasks and collect operation counts - let mut total_operations = 0; + let mut avg_time = 0; for handle in handles { - total_operations += handle.await.unwrap(); + avg_time += handle.await.unwrap(); } + avg_time /= task_count as u128; let elapsed = start_time.elapsed(); - let ops_per_second = total_operations as f64 / elapsed.as_secs_f64(); - let avg_latency_us = elapsed.as_micros() as f64 / total_operations as f64; println!("\n=== TTLMap Lock Contention Benchmark ==="); println!("Tasks: {}", task_count); - println!("Operations per task: {}", operations_per_task); - println!("Total operations: {}", total_operations); println!("Total time: {:.2?}", elapsed); - println!("Throughput: {:.0} ops/sec", ops_per_second); - println!("Average latency: {:.2} μs per operation", avg_latency_us); + println!("Average latency: {:.2} μs per operation", avg_time / 1_000); println!("Entries remaining: {}", ttl_map.data.len()); - - // The benchmark passes if it completes without deadlocks - // Performance metrics are printed for analysis - assert!(ops_per_second > 0.0); // Sanity check + println!( + "DashMap Lock contention time: {}ms", + ttl_map.dash_map_lock_contention_time.load(Ordering::SeqCst) / 1_000_000 + ); + println!( + "Mutex Lock contention time: {}ms", + ttl_map.mutex_lock_contention_time.load(Ordering::SeqCst) / 1_000_000 + ); } - } From 248a7b7e4189bee1fa35930806c3dbdec28ccab9 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Thu, 14 Aug 2025 22:35:32 -0400 Subject: [PATCH 4/5] add contention factor --- src/common/ttl_map.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index 266ac52..4926d1b 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -131,10 +131,7 @@ where { let mut new_entry = false; let start = std::time::Instant::now(); - let entry = self.data.entry(key.clone()); - self.dash_map_lock_contention_time - .fetch_add(start.elapsed().as_nanos() as usize, Relaxed); - let value = match entry { + let value = match self.data.entry(key.clone()) { Entry::Vacant(entry) => { let value = f(); entry.insert(value.clone()); @@ -143,6 +140,8 @@ where } Entry::Occupied(entry) => entry.get().clone(), }; + self.dash_map_lock_contention_time + .fetch_add(start.elapsed().as_nanos() as usize, Relaxed); // Insert the key into the previous bucket, meaning the key will be evicted // when the wheel completes a full rotation. @@ -369,6 +368,7 @@ mod tests { let start_time = Instant::now(); let task_count = 100_000; + let contention_factor = 10; // Spawn 10 tasks that repeatedly read the same keys let mut handles = Vec::new(); @@ -377,7 +377,7 @@ mod tests { let handle = tokio::spawn(async move { // All tasks fight for the same keys - maximum contention let start = Instant::now(); - let _value = map.get_or_init(rand::random(), || task_id * 1000).await; + let _value = map.get_or_init(task_id % (task_count / contention_factor), || task_id * 1000).await; start.elapsed().as_nanos() }); handles.push(handle); From 25ae1c571a836aa869567ca37629fcad340297c6 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Thu, 14 Aug 2025 22:49:35 -0400 Subject: [PATCH 5/5] random --- src/common/ttl_map.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index 4926d1b..d11777b 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -368,7 +368,6 @@ mod tests { let start_time = Instant::now(); let task_count = 100_000; - let contention_factor = 10; // Spawn 10 tasks that repeatedly read the same keys let mut handles = Vec::new(); @@ -377,7 +376,7 @@ mod tests { let handle = tokio::spawn(async move { // All tasks fight for the same keys - maximum contention let start = Instant::now(); - let _value = map.get_or_init(task_id % (task_count / contention_factor), || task_id * 1000).await; + let _value = map.get_or_init(rand::random(), || task_id * 1000).await; start.elapsed().as_nanos() }); handles.push(handle);