Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 79 additions & 14 deletions src/common/ttl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ where
/// new creates a new Bucket
fn new<V>(data: Arc<DashMap<K, V>>) -> Self
where
V: Default + Clone + Send + Sync + 'static,
V: Send + Sync + 'static,
{
// TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
// introduce a dynamic GC period to ensure that GC can keep up.
Expand Down Expand Up @@ -84,7 +84,7 @@ where
/// task is responsible for managing a subset of keys in the map.
async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
where
V: Default + Clone + Send + Sync + 'static,
V: Send + Sync + 'static,
{
let mut shard = HashSet::new();
while let Some(op) = rx.recv().await {
Expand Down Expand Up @@ -207,7 +207,7 @@ where

/// get_or_default executes the provided closure with a reference to the map entry for the given key.
/// If the key does not exist, it inserts a new entry with the default value.
pub fn get_or_init<F>(&self, key: K, f: F) -> V
pub fn get_or_init<F>(&self, key: K, init: F) -> V
where
F: FnOnce() -> V,
{
Expand All @@ -218,7 +218,7 @@ where

let value = match self.data.entry(key.clone()) {
Entry::Vacant(entry) => {
let value = f();
let value = init();
entry.insert(value.clone());
new_entry = true;
value
Expand Down Expand Up @@ -250,6 +250,25 @@ where
value
}

/// Removes the key from the map.
/// TODO: Consider removing the key from the time bucket as well. We would need to know which
/// bucket the key was in to do this. One idea is to store the bucket idx in the map value.
pub fn remove(&self, key: K) {
self.data.remove(&key);
}

/// Returns the number of entries currently stored in the map
#[cfg(test)]
pub fn len(&self) -> usize {
self.data.len()
}

/// Returns an iterator over the keys currently stored in the map
#[cfg(test)]
pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
self.data.iter().map(|entry| entry.key().clone())
}

/// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
/// function terminates if `shutdown` is signalled.
async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
Expand Down Expand Up @@ -349,6 +368,7 @@ mod tests {
// All entries expired
}

// assert_eventually checks a condition every 10ms for a maximum of timeout
async fn assert_eventually<F>(assertion: F, timeout: Duration)
where
F: Fn() -> bool,
Expand All @@ -358,12 +378,12 @@ mod tests {
if assertion() {
return;
}
tokio::time::sleep(Duration::from_millis(50)).await;
tokio::time::sleep(Duration::from_millis(10)).await;
}
panic!("Assertion failed within {:?}", timeout);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_concurrent_gc_and_access() {
let ttl_map = TTLMap::<String, i32>::try_new(TTLMapConfig {
ttl: Duration::from_millis(10),
Expand All @@ -377,22 +397,30 @@ mod tests {

// Spawn 5 concurrent tasks
let mut handles = Vec::new();
for task_id in 0..5 {
for task_id in 0..10 {
let map = Arc::clone(&ttl_map);
let handle = tokio::spawn(async move {
for i in 0..20 {
let key = format!("task{}_key{}", task_id, i % 4);
map.get_or_init(key, || task_id * 100 + i);
sleep(Duration::from_millis(1)).await;
handles.push(tokio::spawn(async move {
for i in 0..100 {
let key = format!("task{}_key{}", task_id, i % 10);
map.get_or_init(key.clone(), || task_id * 100 + i);
}
});
handles.push(handle);
}));
let map2 = Arc::clone(&ttl_map);
handles.push(tokio::spawn(async move {
// Remove some keys which may or may not exist.
for i in 0..50 {
let key = format!("task{}_key{}", task_id, i % 15);
map2.remove(key)
}
}));
}

// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}

assert_eventually(|| ttl_map.data.len() == 0, Duration::from_millis(20)).await;
}

#[tokio::test]
Expand Down Expand Up @@ -475,4 +503,41 @@ mod tests {
ttl_map.metrics.ttl_accounting_time.load(Ordering::SeqCst) / 1_000_000
);
}

#[tokio::test]
async fn test_remove_with_manual_gc() {
let ttl_map = TTLMap::<String, i32>::_new(TTLMapConfig {
ttl: Duration::from_millis(50),
tick: Duration::from_millis(10),
});

ttl_map.get_or_init("key1".to_string(), || 100);
ttl_map.get_or_init("key2".to_string(), || 200);
ttl_map.get_or_init("key3".to_string(), || 300);
assert_eq!(ttl_map.data.len(), 3);

// Remove key2 and verify the others remain.
ttl_map.remove("key2".to_string());
assert_eq!(ttl_map.data.len(), 2);
let val1 = ttl_map.get_or_init("key1".to_string(), || 999);
assert_eq!(val1, 100);
let val3 = ttl_map.get_or_init("key3".to_string(), || 999);
assert_eq!(val3, 300);

// key2 should be recreated with new value
let val2 = ttl_map.get_or_init("key2".to_string(), || 999);
assert_eq!(val2, 999); // New value since it was removed
assert_eq!(ttl_map.data.len(), 3);
let val3 = ttl_map.get_or_init("key2".to_string(), || 200);
assert_eq!(val3, 999);

// Remove key1 before GCing.
ttl_map.remove("key1".to_string());

// Run GC and verify the map is empty.
for _ in 0..5 {
TTLMap::<String, i32>::gc(ttl_map.time.clone(), &ttl_map.buckets);
}
assert_eventually(|| ttl_map.data.len() == 0, Duration::from_millis(100)).await;
}
}
Loading