Skip to content

Commit 03ffc02

Browse files
Merge pull request #108 from datafusion-contrib/js/hook-up-ttl-map-2
do_get: use TTL map to store task state
2 parents bd512cb + 0a5778e commit 03ffc02

File tree

4 files changed

+291
-32
lines changed

4 files changed

+291
-32
lines changed

src/common/ttl_map.rs

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ where
5656
/// new creates a new Bucket
5757
fn new<V>(data: Arc<DashMap<K, V>>) -> Self
5858
where
59-
V: Default + Clone + Send + Sync + 'static,
59+
V: Send + Sync + 'static,
6060
{
6161
// TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
6262
// introduce a dynamic GC period to ensure that GC can keep up.
@@ -84,7 +84,7 @@ where
8484
/// task is responsible for managing a subset of keys in the map.
8585
async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
8686
where
87-
V: Default + Clone + Send + Sync + 'static,
87+
V: Send + Sync + 'static,
8888
{
8989
let mut shard = HashSet::new();
9090
while let Some(op) = rx.recv().await {
@@ -207,7 +207,7 @@ where
207207

208208
/// get_or_init returns the value stored in the map for the given key.
209209
/// If the key does not exist, it inserts a new entry with the value returned by the provided closure.
210-
pub fn get_or_init<F>(&self, key: K, f: F) -> V
210+
pub fn get_or_init<F>(&self, key: K, init: F) -> V
211211
where
212212
F: FnOnce() -> V,
213213
{
@@ -218,7 +218,7 @@ where
218218

219219
let value = match self.data.entry(key.clone()) {
220220
Entry::Vacant(entry) => {
221-
let value = f();
221+
let value = init();
222222
entry.insert(value.clone());
223223
new_entry = true;
224224
value
@@ -250,6 +250,25 @@ where
250250
value
251251
}
252252

253+
/// Removes the key from the map.
254+
/// TODO: Consider removing the key from the time bucket as well. We would need to know which
255+
/// bucket the key was in to do this. One idea is to store the bucket idx in the map value.
256+
pub fn remove(&self, key: K) {
257+
self.data.remove(&key);
258+
}
259+
260+
/// Returns the number of entries currently stored in the map
261+
#[cfg(test)]
262+
pub fn len(&self) -> usize {
263+
self.data.len()
264+
}
265+
266+
/// Returns an iterator over the keys currently stored in the map
267+
#[cfg(test)]
268+
pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
269+
self.data.iter().map(|entry| entry.key().clone())
270+
}
271+
253272
/// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
254273
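/// function terminates if `shutdown` is signalled. NOTE(review): the signature takes no `shutdown` parameter; confirm how this loop is actually stopped.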
255274
async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
@@ -349,6 +368,7 @@ mod tests {
349368
// All entries expired
350369
}
351370

371+
// assert_eventually checks a condition every 10ms for a maximum of timeout
352372
async fn assert_eventually<F>(assertion: F, timeout: Duration)
353373
where
354374
F: Fn() -> bool,
@@ -358,12 +378,12 @@ mod tests {
358378
if assertion() {
359379
return;
360380
}
361-
tokio::time::sleep(Duration::from_millis(50)).await;
381+
tokio::time::sleep(Duration::from_millis(10)).await;
362382
}
363383
panic!("Assertion failed within {:?}", timeout);
364384
}
365385

366-
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
386+
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
367387
async fn test_concurrent_gc_and_access() {
368388
let ttl_map = TTLMap::<String, i32>::try_new(TTLMapConfig {
369389
ttl: Duration::from_millis(10),
@@ -377,22 +397,30 @@ mod tests {
377397

378398
// Spawn 10 pairs of concurrent tasks: one inserting keys and one removing keys.
379399
let mut handles = Vec::new();
380-
for task_id in 0..5 {
400+
for task_id in 0..10 {
381401
let map = Arc::clone(&ttl_map);
382-
let handle = tokio::spawn(async move {
383-
for i in 0..20 {
384-
let key = format!("task{}_key{}", task_id, i % 4);
385-
map.get_or_init(key, || task_id * 100 + i);
386-
sleep(Duration::from_millis(1)).await;
402+
handles.push(tokio::spawn(async move {
403+
for i in 0..100 {
404+
let key = format!("task{}_key{}", task_id, i % 10);
405+
map.get_or_init(key.clone(), || task_id * 100 + i);
387406
}
388-
});
389-
handles.push(handle);
407+
}));
408+
let map2 = Arc::clone(&ttl_map);
409+
handles.push(tokio::spawn(async move {
410+
// Remove some keys which may or may not exist.
411+
for i in 0..50 {
412+
let key = format!("task{}_key{}", task_id, i % 15);
413+
map2.remove(key)
414+
}
415+
}));
390416
}
391417

392418
// Wait for all tasks to complete
393419
for handle in handles {
394420
handle.await.unwrap();
395421
}
422+
423+
assert_eventually(|| ttl_map.data.len() == 0, Duration::from_millis(20)).await;
396424
}
397425

398426
#[tokio::test]
@@ -475,4 +503,41 @@ mod tests {
475503
ttl_map.metrics.ttl_accounting_time.load(Ordering::SeqCst) / 1_000_000
476504
);
477505
}
506+
507+
#[tokio::test]
508+
async fn test_remove_with_manual_gc() {
509+
let ttl_map = TTLMap::<String, i32>::_new(TTLMapConfig {
510+
ttl: Duration::from_millis(50),
511+
tick: Duration::from_millis(10),
512+
});
513+
514+
ttl_map.get_or_init("key1".to_string(), || 100);
515+
ttl_map.get_or_init("key2".to_string(), || 200);
516+
ttl_map.get_or_init("key3".to_string(), || 300);
517+
assert_eq!(ttl_map.data.len(), 3);
518+
519+
// Remove key2 and verify the others remain.
520+
ttl_map.remove("key2".to_string());
521+
assert_eq!(ttl_map.data.len(), 2);
522+
let val1 = ttl_map.get_or_init("key1".to_string(), || 999);
523+
assert_eq!(val1, 100);
524+
let val3 = ttl_map.get_or_init("key3".to_string(), || 999);
525+
assert_eq!(val3, 300);
526+
527+
// key2 should be recreated with new value
528+
let val2 = ttl_map.get_or_init("key2".to_string(), || 999);
529+
assert_eq!(val2, 999); // New value since it was removed
530+
assert_eq!(ttl_map.data.len(), 3);
531+
let val3 = ttl_map.get_or_init("key2".to_string(), || 200);
532+
assert_eq!(val3, 999);
533+
534+
// Remove key1 before GCing.
535+
ttl_map.remove("key1".to_string());
536+
537+
// Run GC and verify the map is empty.
538+
for _ in 0..5 {
539+
TTLMap::<String, i32>::gc(ttl_map.time.clone(), &ttl_map.buckets);
540+
}
541+
assert_eventually(|| ttl_map.data.len() == 0, Duration::from_millis(100)).await;
542+
}
478543
}

0 commit comments

Comments
 (0)