Skip to content

Commit 3e4dd65

Browse files
Create TTL map with time wheel architecture
This changes adds a DashMap-like struct which has a background tasks to clean up entries that have outlived a configurable TTL. This struct is simliar to https://github.com/moka-rs/moka, which also uses time wheels. Having our own struct avoids introducing a large dependency, which keeps this project closer to vanilla datafusion.
1 parent 0dbf427 commit 3e4dd65

File tree

2 files changed

+344
-0
lines changed

2 files changed

+344
-0
lines changed

src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub mod ttl_map;
12
pub mod util;

src/common/ttl_map.rs

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
/*
2+
TTLMap is a DashMap that automatically removes entries after a specified time-to-live (TTL).
3+
4+
How the Time Wheel Works
5+
6+
Time Buckets: [0] [1] [2] [3] [4] [5] [6] [7] ...
7+
Current Time: ^
8+
|
9+
time % buckets.len()
10+
11+
When inserting key "A" at time=2:
12+
- Key "A" goes into bucket[(2-1) % 8] = bucket[1]
13+
- Key "A" will be expired when time advances to bucket[1] again
14+
15+
Generally, keys in a bucket expire when the wheel makes a full rotation, making
16+
the total TTL equal to the tick duration * buckets.len().
17+
18+
Usage
19+
```rust
20+
let params = TTLMapParams { tick: Duration::from_secs(30), ttl: Duration::from_mins(5) };
21+
let ttl_map = TTLMap::new(params).await.unwrap();
22+
let value = ttl_map.get_or_init(key, || initial_value).await;
23+
```
24+
25+
TODO: If an existing entry is accessed, we don't extend its TTL. It's unclear if this is
26+
necessary for any use cases. This functionality could be added if needed.
27+
*/
28+
use dashmap::{DashMap, Entry};
29+
use datafusion::error::DataFusionError;
30+
use std::collections::HashSet;
31+
use std::hash::Hash;
32+
use std::sync::atomic::AtomicU64;
33+
use std::sync::Arc;
34+
use std::time::Duration;
35+
use tokio::sync::Mutex;
36+
37+
// TTLMap is a key-value store that automatically removes entries after a specified time-to-live.
38+
pub struct TTLMap<K, V> {
39+
/// Time wheel buckets containing keys to be expired. Each bucket epresents
40+
/// a time slot. Keys in bucket[i] will be expired when time % buckets.len() == i
41+
buckets: Arc<Mutex<Vec<HashSet<K>>>>,
42+
43+
/// The actual key-value storage using DashMap for concurrent access
44+
data: Arc<DashMap<K, V>>,
45+
46+
/// Atomic counter tracking the current time slot for the time wheel.
47+
/// Incremented by the background GC task every `tick` duration.
48+
time: Arc<AtomicU64>,
49+
50+
/// Background task handle for the garbage collection process.
51+
/// When dropped, the GC task is automatically aborted.
52+
_task: Option<tokio::task::JoinHandle<()>>,
53+
54+
// grandularity of the time wheel. How often a bucket is cleared.
55+
tick: Duration,
56+
}
57+
58+
pub struct TTLMapParams {
59+
// tick is how often the map is checks for expired entries
60+
// must be less than ttl
61+
pub tick: Duration,
62+
// ttl is the time-to-live for entries
63+
pub ttl: Duration,
64+
}
65+
66+
impl Default for TTLMapParams {
67+
fn default() -> Self {
68+
Self {
69+
tick: Duration::from_secs(3),
70+
ttl: Duration::from_secs(60),
71+
}
72+
}
73+
}
74+
75+
impl<K, V> TTLMap<K, V>
76+
where
77+
K: Eq + Hash + Send + Sync + Clone + 'static,
78+
V: Default + Clone + Send + Sync + 'static,
79+
{
80+
// new creates a new TTLMap.
81+
pub async fn new(params: TTLMapParams) -> Result<Self, DataFusionError> {
82+
if params.tick > params.ttl {
83+
return Err(DataFusionError::Configuration(
84+
"tick duration must be less than or equal to ttl duration".to_string(),
85+
));
86+
}
87+
let mut map = Self::_new(params.tick, params.ttl).await;
88+
map._start_gc();
89+
Ok(map)
90+
}
91+
92+
async fn _new(tick: Duration, ttl: Duration) -> Self {
93+
let bucket_count = (ttl.as_millis() / tick.as_millis()) as usize;
94+
let mut buckets = Vec::with_capacity(bucket_count);
95+
for _ in 0..bucket_count {
96+
buckets.push(HashSet::new());
97+
}
98+
let stage_targets = Arc::new(DashMap::new());
99+
let time_wheel = Arc::new(Mutex::new(buckets));
100+
let time = Arc::new(AtomicU64::new(0));
101+
Self {
102+
buckets: time_wheel,
103+
data: stage_targets,
104+
time,
105+
_task: None,
106+
tick,
107+
}
108+
}
109+
110+
// Start and set the background GC task.
111+
fn _start_gc(&mut self) {
112+
self._task = Some(tokio::spawn(Self::run_gc_loop(
113+
self.data.clone(),
114+
self.buckets.clone(),
115+
self.time.clone(),
116+
self.tick,
117+
)))
118+
}
119+
120+
/// get_or_default executes the provided closure with a reference to the map entry for the given key.
121+
/// If the key does not exist, it inserts a new entry with the default value.
122+
pub async fn get_or_init<F>(&self, key: K, f: F) -> V
123+
where
124+
F: FnOnce() -> V,
125+
{
126+
let mut new_entry = false;
127+
let value = match self.data.entry(key.clone()) {
128+
Entry::Vacant(entry) => {
129+
let value = f();
130+
entry.insert(value.clone());
131+
new_entry = true;
132+
value
133+
}
134+
Entry::Occupied(entry) => entry.get().clone(),
135+
};
136+
137+
// Insert the key into the previous bucket, meaning the key will be evicted
138+
// when the wheel completes a full rotation.
139+
if new_entry {
140+
let time = self.time.load(std::sync::atomic::Ordering::SeqCst);
141+
{
142+
let mut buckets = self.buckets.lock().await;
143+
let bucket_index = (time.wrapping_sub(1)) % buckets.len() as u64;
144+
buckets[bucket_index as usize].insert(key);
145+
}
146+
}
147+
148+
value
149+
}
150+
151+
/// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
152+
/// function terminates if `shutdown` is signalled.
153+
async fn run_gc_loop(
154+
map: Arc<DashMap<K, V>>,
155+
time_wheel: Arc<Mutex<Vec<HashSet<K>>>>,
156+
time: Arc<AtomicU64>,
157+
period: Duration,
158+
) {
159+
loop {
160+
Self::gc(map.clone(), time_wheel.clone(), time.clone()).await;
161+
tokio::time::sleep(period).await;
162+
}
163+
}
164+
165+
/// gc clears expired entries from the map and advances time by 1.
166+
async fn gc(
167+
map: Arc<DashMap<K, V>>,
168+
time_wheel: Arc<Mutex<Vec<HashSet<K>>>>,
169+
time: Arc<AtomicU64>,
170+
) {
171+
let keys = {
172+
let mut guard = time_wheel.lock().await;
173+
let len = guard.len();
174+
let index = time.load(std::sync::atomic::Ordering::SeqCst) % len as u64;
175+
// Replace the HashSet at the index with an empty one and return the original
176+
std::mem::replace(&mut guard[index as usize], HashSet::new())
177+
};
178+
179+
// Remove expired keys from the map.
180+
// TODO: it may be worth exploring if we can group keys by shard and do a batched
181+
// remove.
182+
for key in keys {
183+
map.remove(&key);
184+
}
185+
// May wrap.
186+
time.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
187+
}
188+
}
189+
190+
#[cfg(test)]
191+
mod tests {
192+
use super::*;
193+
use std::sync::atomic::Ordering;
194+
use tokio::time::{sleep, Duration};
195+
196+
#[tokio::test]
197+
async fn test_basic_insert_and_get() {
198+
let ttl_map =
199+
TTLMap::<String, i32>::_new(Duration::from_millis(100), Duration::from_secs(1)).await;
200+
201+
ttl_map.get_or_init("key1".to_string(), || 42).await;
202+
203+
let value = ttl_map.get_or_init("key1".to_string(), || 0).await;
204+
assert_eq!(value, 42);
205+
}
206+
207+
#[tokio::test]
208+
async fn test_time_wheel_bucket_calculation() {
209+
let ttl_map =
210+
TTLMap::<String, i32>::_new(Duration::from_millis(100), Duration::from_secs(1)).await;
211+
212+
// With 1s TTL and 100ms tick, we should have 10 buckets
213+
assert_eq!(ttl_map.buckets.lock().await.len(), 10);
214+
}
215+
216+
#[tokio::test]
217+
async fn test_gc_expiration() {
218+
let ttl_map =
219+
TTLMap::<String, i32>::_new(Duration::from_millis(100), Duration::from_secs(1)).await;
220+
221+
// Initial batch of entries
222+
ttl_map.get_or_init("key1".to_string(), || 42).await;
223+
ttl_map.get_or_init("key2".to_string(), || 84).await;
224+
assert_eq!(ttl_map.data.len(), 2);
225+
226+
// Run partial GC cycles (should not expire yet)
227+
for _ in 0..5 {
228+
TTLMap::gc(
229+
ttl_map.data.clone(),
230+
ttl_map.buckets.clone(),
231+
ttl_map.time.clone(),
232+
)
233+
.await;
234+
}
235+
assert_eq!(ttl_map.data.len(), 2); // Still there
236+
237+
// Add more entries mid-cycle
238+
ttl_map.get_or_init("key3".to_string(), || 168).await;
239+
ttl_map.get_or_init("key4".to_string(), || 0).await; // Default value (0)
240+
ttl_map.get_or_init("key5".to_string(), || 210).await;
241+
assert_eq!(ttl_map.data.len(), 5);
242+
243+
// Verify default value was set
244+
let default_value = ttl_map.get_or_init("key4".to_string(), || 0).await;
245+
assert_eq!(default_value, 0);
246+
247+
// Complete the first rotation to expire initial entries
248+
for _ in 5..10 {
249+
TTLMap::gc(
250+
ttl_map.data.clone(),
251+
ttl_map.buckets.clone(),
252+
ttl_map.time.clone(),
253+
)
254+
.await;
255+
}
256+
assert_eq!(ttl_map.data.len(), 3); // Initial entries expired, new entries still alive
257+
258+
// Add entries after expiration
259+
ttl_map.get_or_init("new_key1".to_string(), || 999).await;
260+
ttl_map.get_or_init("new_key2".to_string(), || 0).await; // Default value
261+
assert_eq!(ttl_map.data.len(), 5); // 3 from mid-cycle + 2 new ones
262+
263+
// Verify values
264+
let value1 = ttl_map.get_or_init("new_key1".to_string(), || 0).await;
265+
assert_eq!(value1, 999);
266+
let value2 = ttl_map.get_or_init("new_key2".to_string(), || 0).await;
267+
assert_eq!(value2, 0);
268+
269+
// Run additional GC cycles to expire remaining entries
270+
// Mid-cycle entries (bucket 4) expire at time=14, late entries (bucket 9) expire at time=19
271+
for _ in 10..20 {
272+
TTLMap::gc(
273+
ttl_map.data.clone(),
274+
ttl_map.buckets.clone(),
275+
ttl_map.time.clone(),
276+
)
277+
.await;
278+
}
279+
assert_eq!(ttl_map.data.len(), 0); // All entries expired
280+
}
281+
282+
#[tokio::test]
283+
async fn test_concurrent_gc_and_access() {
284+
let ttl_map = TTLMap::<String, i32>::new(TTLMapParams {
285+
tick: Duration::from_millis(2),
286+
ttl: Duration::from_millis(10),
287+
})
288+
.await
289+
.unwrap();
290+
291+
assert!(ttl_map._task.is_some());
292+
293+
let ttl_map = Arc::new(ttl_map);
294+
295+
// Spawn 5 concurrent tasks
296+
let mut handles = Vec::new();
297+
for task_id in 0..5 {
298+
let map = Arc::clone(&ttl_map);
299+
let handle = tokio::spawn(async move {
300+
for i in 0..20 {
301+
let key = format!("task{}_key{}", task_id, i % 4);
302+
map.get_or_init(key, || task_id * 100 + i).await;
303+
sleep(Duration::from_millis(1)).await;
304+
}
305+
});
306+
handles.push(handle);
307+
}
308+
309+
// Wait for all tasks to complete
310+
for handle in handles {
311+
handle.await.unwrap();
312+
}
313+
}
314+
315+
#[tokio::test]
316+
async fn test_wraparound_time() {
317+
let ttl_map = TTLMap::<String, i32>::_new(
318+
Duration::from_millis(10),
319+
Duration::from_millis(20), // 2 buckets
320+
)
321+
.await;
322+
323+
// Manually set time near overflow
324+
ttl_map.time.store(u64::MAX - 2, Ordering::SeqCst);
325+
326+
ttl_map.get_or_init("test_key".to_string(), || 999).await;
327+
328+
// Run GC to cause time wraparound
329+
for _ in 0..5 {
330+
TTLMap::gc(
331+
ttl_map.data.clone(),
332+
ttl_map.buckets.clone(),
333+
ttl_map.time.clone(),
334+
)
335+
.await;
336+
}
337+
338+
// Entry should be expired and time should have wrapped
339+
assert_eq!(ttl_map.data.len(), 0);
340+
let final_time = ttl_map.time.load(Ordering::SeqCst);
341+
assert!(final_time < 100);
342+
}
343+
}

0 commit comments

Comments
 (0)