@@ -250,6 +250,13 @@ where
250250 value
251251 }
252252
253+ /// Removes the key from the map.
254+ /// TODO: Consider removing the key from the time bucket as well. We would need to know which
255+ /// bucket the key was in to do this. One idea is to store the bucket idx in the map value.
256+ pub fn remove ( & self , key : K ) {
257+ self . data . remove ( & key) ;
258+ }
259+
253260 /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
254261 /// function terminates if `shutdown` is signalled.
255262 async fn run_gc_loop ( time : Arc < AtomicU64 > , period : Duration , buckets : & [ Bucket < K > ] ) {
@@ -349,6 +356,7 @@ mod tests {
349356 // All entries expired
350357 }
351358
359+ // assert_eventually checks a condition every 10ms for a maximum of timeout
352360 async fn assert_eventually < F > ( assertion : F , timeout : Duration )
353361 where
354362 F : Fn ( ) -> bool ,
@@ -358,12 +366,12 @@ mod tests {
358366 if assertion ( ) {
359367 return ;
360368 }
361- tokio:: time:: sleep ( Duration :: from_millis ( 50 ) ) . await ;
369+ tokio:: time:: sleep ( Duration :: from_millis ( 10 ) ) . await ;
362370 }
363371 panic ! ( "Assertion failed within {:?}" , timeout) ;
364372 }
365373
366- #[ tokio:: test( flavor = "multi_thread" , worker_threads = 4 ) ]
374+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 8 ) ]
367375 async fn test_concurrent_gc_and_access ( ) {
368376 let ttl_map = TTLMap :: < String , i32 > :: try_new ( TTLMapConfig {
369377 ttl : Duration :: from_millis ( 10 ) ,
@@ -377,22 +385,30 @@ mod tests {
377385
378386 // Spawn 5 concurrent tasks
379387 let mut handles = Vec :: new ( ) ;
380- for task_id in 0 ..5 {
388+ for task_id in 0 ..10 {
381389 let map = Arc :: clone ( & ttl_map) ;
382- let handle = tokio:: spawn ( async move {
383- for i in 0 ..20 {
384- let key = format ! ( "task{}_key{}" , task_id, i % 4 ) ;
385- map. get_or_init ( key, || task_id * 100 + i) ;
386- sleep ( Duration :: from_millis ( 1 ) ) . await ;
390+ handles. push ( tokio:: spawn ( async move {
391+ for i in 0 ..100 {
392+ let key = format ! ( "task{}_key{}" , task_id, i % 10 ) ;
393+ map. get_or_init ( key. clone ( ) , || task_id * 100 + i) ;
387394 }
388- } ) ;
389- handles. push ( handle) ;
395+ } ) ) ;
396+ let map2 = Arc :: clone ( & ttl_map) ;
397+ handles. push ( tokio:: spawn ( async move {
398+ // Remove some keys which may or may not exist.
399+ for i in 0 ..50 {
400+ let key = format ! ( "task{}_key{}" , task_id, i % 15 ) ;
401+ map2. remove ( key)
402+ }
403+ } ) ) ;
390404 }
391405
392406 // Wait for all tasks to complete
393407 for handle in handles {
394408 handle. await . unwrap ( ) ;
395409 }
410+
411+ assert_eventually ( || ttl_map. data . len ( ) == 0 , Duration :: from_millis ( 20 ) ) . await ;
396412 }
397413
398414 #[ tokio:: test]
@@ -475,4 +491,41 @@ mod tests {
475491 ttl_map. metrics. ttl_accounting_time. load( Ordering :: SeqCst ) / 1_000_000
476492 ) ;
477493 }
494+
495+ #[ tokio:: test]
496+ async fn test_remove_with_manual_gc ( ) {
497+ let ttl_map = TTLMap :: < String , i32 > :: _new ( TTLMapConfig {
498+ ttl : Duration :: from_millis ( 50 ) ,
499+ tick : Duration :: from_millis ( 10 ) ,
500+ } ) ;
501+
502+ ttl_map. get_or_init ( "key1" . to_string ( ) , || 100 ) ;
503+ ttl_map. get_or_init ( "key2" . to_string ( ) , || 200 ) ;
504+ ttl_map. get_or_init ( "key3" . to_string ( ) , || 300 ) ;
505+ assert_eq ! ( ttl_map. data. len( ) , 3 ) ;
506+
507+ // Remove key2 and verify the others remain.
508+ ttl_map. remove ( "key2" . to_string ( ) ) ;
509+ assert_eq ! ( ttl_map. data. len( ) , 2 ) ;
510+ let val1 = ttl_map. get_or_init ( "key1" . to_string ( ) , || 999 ) ;
511+ assert_eq ! ( val1, 100 ) ;
512+ let val3 = ttl_map. get_or_init ( "key3" . to_string ( ) , || 999 ) ;
513+ assert_eq ! ( val3, 300 ) ;
514+
515+ // key2 should be recreated with new value
516+ let val2 = ttl_map. get_or_init ( "key2" . to_string ( ) , || 999 ) ;
517+ assert_eq ! ( val2, 999 ) ; // New value since it was removed
518+ assert_eq ! ( ttl_map. data. len( ) , 3 ) ;
519+ let val3 = ttl_map. get_or_init ( "key2" . to_string ( ) , || 200 ) ;
520+ assert_eq ! ( val3, 999 ) ;
521+
522+ // Remove key1 before GCing.
523+ ttl_map. remove ( "key1" . to_string ( ) ) ;
524+
525+ // Run GC and verify the map is empty.
526+ for _ in 0 ..5 {
527+ TTLMap :: < String , i32 > :: gc ( ttl_map. time . clone ( ) , & ttl_map. buckets ) ;
528+ }
529+ assert_eventually ( || ttl_map. data . len ( ) == 0 , Duration :: from_millis ( 100 ) ) . await ;
530+ }
478531}
0 commit comments