/*
TTLMap is a DashMap that automatically removes entries after a specified time-to-live (TTL).

How the Time Wheel Works

Time Buckets: [0] [1] [2] [3] [4] [5] [6] [7] ...
Current Time:          ^
                       |
                 time % buckets.len()

When inserting key "A" at time=2:
- Key "A" goes into bucket[(2-1) % 8] = bucket[1]
- Key "A" will be expired when the wheel advances to bucket[1] again

Keys in a bucket expire when the wheel completes a full rotation, so the effective
TTL is approximately tick duration * buckets.len(), accurate to within one tick.
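
As a concrete sketch of the index arithmetic above (mirroring what `_new` and
`get_or_init` do below; the 100ms tick and 1s TTL are illustrative values, not
defaults):

```rust
use std::time::Duration;

let tick = Duration::from_millis(100);
let ttl = Duration::from_secs(1);
// One bucket per tick across the TTL window: 1000ms / 100ms = 10 buckets.
let bucket_count = (ttl.as_millis() / tick.as_millis()) as usize;

// A key inserted while the wheel is at `time` goes into the bucket just behind
// the current position, so it is cleared one full rotation later.
let time: u64 = 2;
let bucket_index = time.wrapping_sub(1) % bucket_count as u64; // == 1
```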

Usage
```rust
let params = TTLMapParams { tick: Duration::from_secs(30), ttl: Duration::from_secs(5 * 60) };
let ttl_map = TTLMap::new(params).await.unwrap();
let value = ttl_map.get_or_init(key, || initial_value).await;
```

TODO: If an existing entry is accessed, we don't extend its TTL. It's unclear if this is
necessary for any use cases. This functionality could be added if needed.
 */
use dashmap::{DashMap, Entry};
use datafusion::error::DataFusionError;
use std::collections::HashSet;
use std::hash::Hash;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

/// TTLMap is a key-value store that automatically removes entries after a specified time-to-live.
pub struct TTLMap<K, V> {
    /// Time wheel buckets containing keys to be expired. Each bucket represents
    /// a time slot. Keys in bucket[i] will be expired when time % buckets.len() == i
    buckets: Arc<Mutex<Vec<HashSet<K>>>>,

    /// The actual key-value storage using DashMap for concurrent access
    data: Arc<DashMap<K, V>>,

    /// Atomic counter tracking the current time slot for the time wheel.
    /// Incremented by the background GC task every `tick` duration.
    time: Arc<AtomicU64>,

    /// Background task handle for the garbage collection process.
    /// Note: dropping the handle detaches the GC task rather than aborting it.
    _task: Option<tokio::task::JoinHandle<()>>,

    /// Granularity of the time wheel: how often a bucket is cleared.
    tick: Duration,
}

pub struct TTLMapParams {
    /// How often the map checks for expired entries. Must not exceed `ttl`.
    pub tick: Duration,
    /// Time-to-live for entries.
    pub ttl: Duration,
}

impl Default for TTLMapParams {
    fn default() -> Self {
        Self {
            tick: Duration::from_secs(3),
            ttl: Duration::from_secs(60),
        }
    }
}
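
// With the defaults above (3s tick, 60s TTL), the time wheel has 60 / 3 = 20
// buckets, so an entry inserted now is removed by the GC task roughly 60
// seconds later, give or take one tick.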

impl<K, V> TTLMap<K, V>
where
    K: Eq + Hash + Send + Sync + Clone + 'static,
    V: Default + Clone + Send + Sync + 'static,
{
    /// Creates a new TTLMap and starts the background GC task.
    pub async fn new(params: TTLMapParams) -> Result<Self, DataFusionError> {
        if params.tick > params.ttl {
            return Err(DataFusionError::Configuration(
                "tick duration must be less than or equal to ttl duration".to_string(),
            ));
        }
        let mut map = Self::_new(params.tick, params.ttl).await;
        map._start_gc();
        Ok(map)
    }

    async fn _new(tick: Duration, ttl: Duration) -> Self {
        // One bucket per tick across the TTL window.
        let bucket_count = (ttl.as_millis() / tick.as_millis()) as usize;
        let mut buckets = Vec::with_capacity(bucket_count);
        for _ in 0..bucket_count {
            buckets.push(HashSet::new());
        }
        let data = Arc::new(DashMap::new());
        let time_wheel = Arc::new(Mutex::new(buckets));
        let time = Arc::new(AtomicU64::new(0));
        Self {
            buckets: time_wheel,
            data,
            time,
            _task: None,
            tick,
        }
    }

    // Spawn the background GC task and store its handle.
    fn _start_gc(&mut self) {
        self._task = Some(tokio::spawn(Self::run_gc_loop(
            self.data.clone(),
            self.buckets.clone(),
            self.time.clone(),
            self.tick,
        )))
    }

    /// get_or_init returns the value for `key`, inserting the value produced by `f`
    /// if the key is not present. Accessing an existing entry does not extend its TTL.
    pub async fn get_or_init<F>(&self, key: K, f: F) -> V
    where
        F: FnOnce() -> V,
    {
        let mut new_entry = false;
        let value = match self.data.entry(key.clone()) {
            Entry::Vacant(entry) => {
                let value = f();
                entry.insert(value.clone());
                new_entry = true;
                value
            }
            Entry::Occupied(entry) => entry.get().clone(),
        };

        // Insert the key into the previous bucket, meaning the key will be evicted
        // when the wheel completes a full rotation.
        if new_entry {
            let time = self.time.load(std::sync::atomic::Ordering::SeqCst);
            {
                let mut buckets = self.buckets.lock().await;
                let bucket_index = (time.wrapping_sub(1)) % buckets.len() as u64;
                buckets[bucket_index as usize].insert(key);
            }
        }

        value
    }

    /// run_gc_loop continuously clears expired entries from the map, running one GC
    /// pass every `period`. The loop never exits on its own; it stops only when the
    /// task driving it is aborted or the runtime shuts down.
    async fn run_gc_loop(
        map: Arc<DashMap<K, V>>,
        time_wheel: Arc<Mutex<Vec<HashSet<K>>>>,
        time: Arc<AtomicU64>,
        period: Duration,
    ) {
        loop {
            Self::gc(map.clone(), time_wheel.clone(), time.clone()).await;
            tokio::time::sleep(period).await;
        }
    }

    /// gc clears expired entries from the map and advances time by 1.
    async fn gc(
        map: Arc<DashMap<K, V>>,
        time_wheel: Arc<Mutex<Vec<HashSet<K>>>>,
        time: Arc<AtomicU64>,
    ) {
        let keys = {
            let mut guard = time_wheel.lock().await;
            let len = guard.len();
            let index = time.load(std::sync::atomic::Ordering::SeqCst) % len as u64;
            // Replace the HashSet at the index with an empty one and return the original
            std::mem::replace(&mut guard[index as usize], HashSet::new())
        };

        // Remove expired keys from the map.
        // TODO: it may be worth exploring if we can group keys by shard and do a batched
        // remove.
        for key in keys {
            map.remove(&key);
        }
        // May wrap.
        time.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_insert_and_get() {
        let ttl_map =
            TTLMap::<String, i32>::_new(Duration::from_millis(100), Duration::from_secs(1)).await;

        ttl_map.get_or_init("key1".to_string(), || 42).await;

        let value = ttl_map.get_or_init("key1".to_string(), || 0).await;
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_time_wheel_bucket_calculation() {
        let ttl_map =
            TTLMap::<String, i32>::_new(Duration::from_millis(100), Duration::from_secs(1)).await;

        // With 1s TTL and 100ms tick, we should have 10 buckets
        assert_eq!(ttl_map.buckets.lock().await.len(), 10);
    }

    #[tokio::test]
    async fn test_gc_expiration() {
        let ttl_map =
            TTLMap::<String, i32>::_new(Duration::from_millis(100), Duration::from_secs(1)).await;

        // Initial batch of entries
        ttl_map.get_or_init("key1".to_string(), || 42).await;
        ttl_map.get_or_init("key2".to_string(), || 84).await;
        assert_eq!(ttl_map.data.len(), 2);

        // Run partial GC cycles (should not expire yet)
        for _ in 0..5 {
            TTLMap::gc(
                ttl_map.data.clone(),
                ttl_map.buckets.clone(),
                ttl_map.time.clone(),
            )
            .await;
        }
        assert_eq!(ttl_map.data.len(), 2); // Still there

        // Add more entries mid-cycle
        ttl_map.get_or_init("key3".to_string(), || 168).await;
        ttl_map.get_or_init("key4".to_string(), || 0).await; // Default value (0)
        ttl_map.get_or_init("key5".to_string(), || 210).await;
        assert_eq!(ttl_map.data.len(), 5);

        // Verify default value was set
        let default_value = ttl_map.get_or_init("key4".to_string(), || 0).await;
        assert_eq!(default_value, 0);

        // Complete the first rotation to expire initial entries
        for _ in 5..10 {
            TTLMap::gc(
                ttl_map.data.clone(),
                ttl_map.buckets.clone(),
                ttl_map.time.clone(),
            )
            .await;
        }
        assert_eq!(ttl_map.data.len(), 3); // Initial entries expired, new entries still alive

        // Add entries after expiration
        ttl_map.get_or_init("new_key1".to_string(), || 999).await;
        ttl_map.get_or_init("new_key2".to_string(), || 0).await; // Default value
        assert_eq!(ttl_map.data.len(), 5); // 3 from mid-cycle + 2 new ones

        // Verify values
        let value1 = ttl_map.get_or_init("new_key1".to_string(), || 0).await;
        assert_eq!(value1, 999);
        let value2 = ttl_map.get_or_init("new_key2".to_string(), || 0).await;
        assert_eq!(value2, 0);

        // Run additional GC cycles to expire remaining entries
        // Mid-cycle entries (bucket 4) expire at time=14, late entries (bucket 9) expire at time=19
        for _ in 10..20 {
            TTLMap::gc(
                ttl_map.data.clone(),
                ttl_map.buckets.clone(),
                ttl_map.time.clone(),
            )
            .await;
        }
        assert_eq!(ttl_map.data.len(), 0); // All entries expired
    }

    #[tokio::test]
    async fn test_concurrent_gc_and_access() {
        let ttl_map = TTLMap::<String, i32>::new(TTLMapParams {
            tick: Duration::from_millis(2),
            ttl: Duration::from_millis(10),
        })
        .await
        .unwrap();

        assert!(ttl_map._task.is_some());

        let ttl_map = Arc::new(ttl_map);

        // Spawn 5 concurrent tasks
        let mut handles = Vec::new();
        for task_id in 0..5 {
            let map = Arc::clone(&ttl_map);
            let handle = tokio::spawn(async move {
                for i in 0..20 {
                    let key = format!("task{}_key{}", task_id, i % 4);
                    map.get_or_init(key, || task_id * 100 + i).await;
                    sleep(Duration::from_millis(1)).await;
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_wraparound_time() {
        let ttl_map = TTLMap::<String, i32>::_new(
            Duration::from_millis(10),
            Duration::from_millis(20), // 2 buckets
        )
        .await;

        // Manually set time near overflow
        ttl_map.time.store(u64::MAX - 2, Ordering::SeqCst);

        ttl_map.get_or_init("test_key".to_string(), || 999).await;

        // Run GC to cause time wraparound
        for _ in 0..5 {
            TTLMap::gc(
                ttl_map.data.clone(),
                ttl_map.buckets.clone(),
                ttl_map.time.clone(),
            )
            .await;
        }

        // Entry should be expired and time should have wrapped
        assert_eq!(ttl_map.data.len(), 0);
        let final_time = ttl_map.time.load(Ordering::SeqCst);
        assert!(final_time < 100);
    }
}