From 46666b671c9b4ed7ee95e136fb5c61935f0e91c1 Mon Sep 17 00:00:00 2001
From: Jayant Shrivastava
Date: Fri, 1 Aug 2025 18:39:49 -0400
Subject: [PATCH] delegation: add gc task to clean up abandoned contexts

This change addresses a TODO in the `StageDelegation` struct. Previously, it
was possible for `StageContext`s to leak if a context was produced and never
consumed (or if a waiter created a map entry and timed out without any
corresponding producer call).

This change introduces a GC task that clears contexts from the map once they
have been there for longer than a TTL. The proposed TTL is 2 * `wait_timeout`.
When the `StageDelegation` is dropped, the GC task is terminated.

Testing:
- Adds two unit tests (one for `add_delegate_info` and one for
  `wait_for_delegate_info`) which verify that abandoned map entries are
  garbage collected appropriately.
---
 src/stage_delegation/delegation.rs | 202 ++++++++++++++++++++++++++---
 1 file changed, 184 insertions(+), 18 deletions(-)

diff --git a/src/stage_delegation/delegation.rs b/src/stage_delegation/delegation.rs
index 5d4f9e1..3901c2c 100644
--- a/src/stage_delegation/delegation.rs
+++ b/src/stage_delegation/delegation.rs
@@ -2,8 +2,13 @@
 use super::StageContext;
 use dashmap::{DashMap, Entry};
 use datafusion::common::{exec_datafusion_err, exec_err};
 use datafusion::error::DataFusionError;
+use std::sync::Arc;
 use std::time::Duration;
 use tokio::sync::oneshot;
+use tokio::sync::Notify;
+use tokio::time::Instant;
 
 /// In each stage of the distributed plan, there will be N workers. All these workers
 /// need to coordinate to pull data from the next stage, which will contain M workers.
@@ -30,20 +35,70 @@
 /// On 2, the `wait_for_delegate_info` call will create an entry in the [DashMap] with a
 /// [oneshot::Sender], and listen on the other end of the channel [oneshot::Receiver] for
 /// the delegate to put something there.
+///
+/// It's possible for a [StageContext] to "get lost" if `add_delegate_info` is called
+/// without a corresponding call to `wait_for_delegate_info`, or vice versa. To handle
+/// this, a background task reaps any entries that live longer than the `gc_ttl`.
 pub struct StageDelegation {
-    stage_targets: DashMap<(String, usize), Oneof>,
+    stage_targets: Arc<DashMap<(String, usize), Value>>,
     wait_timeout: Duration,
+
+    /// Notifies the GC task to shut down when the [StageDelegation] is dropped.
+    notify: Arc<Notify>,
 }
 
 impl Default for StageDelegation {
     fn default() -> Self {
-        Self {
-            stage_targets: DashMap::default(),
+        let stage_targets = Arc::new(DashMap::default());
+        let notify = Arc::new(Notify::new());
+
+        let result = Self {
+            stage_targets: stage_targets.clone(),
             wait_timeout: Duration::from_secs(5),
-        }
+            notify: notify.clone(),
+        };
+
+        // Spawn the background GC task. Note: this requires a running Tokio runtime.
+        tokio::spawn(run_gc(
+            stage_targets,
+            notify,
+            Duration::from_secs(GC_PERIOD_SECONDS),
+        ));
+
+        result
     }
 }
+
+const GC_PERIOD_SECONDS: u64 = 30;
+
+// run_gc continuously clears expired entries from the map, checking once every
+// `period`. The function terminates when `shutdown` is signalled.
+async fn run_gc(
+    stage_targets: Arc<DashMap<(String, usize), Value>>,
+    shutdown: Arc<Notify>,
+    period: Duration,
+) {
+    loop {
+        tokio::select! {
+            _ = shutdown.notified() => {
+                break;
+            }
+            _ = tokio::time::sleep(period) => {
+                // Performance: this iterator is sharded, so it won't lock the whole map.
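+                // Each shard's lock is held only while that shard is being scanned,
+                // so readers of other shards are not blocked while the GC runs.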
+                stage_targets.retain(|_key, value| value.expiry > Instant::now());
+            }
+        }
+    }
+}
+
+impl Drop for StageDelegation {
+    fn drop(&mut self) {
+        self.notify.notify_one();
+    }
+}
+
 impl StageDelegation {
     /// Puts the [StageContext] info so that an actor can pick it up with `wait_for_delegate_info`.
     ///
@@ -57,9 +112,13 @@
         actor_idx: usize,
         next_stage_context: StageContext,
     ) -> Result<(), DataFusionError> {
         let tx = match self.stage_targets.entry((stage_id, actor_idx)) {
-            Entry::Occupied(entry) => match entry.get() {
-                Oneof::Sender(_) => match entry.remove() {
+            Entry::Occupied(entry) => match entry.get().value {
+                Oneof::Sender(_) => match entry.remove().value {
                     Oneof::Sender(tx) => tx,
                     Oneof::Receiver(_) => unreachable!(),
                 },
@@ -69,17 +128,14 @@
             },
             Entry::Vacant(entry) => {
                 let (tx, rx) = oneshot::channel();
-                entry.insert(Oneof::Receiver(rx));
+                entry.insert(Value {
+                    expiry: Instant::now() + self.gc_ttl(),
+                    value: Oneof::Receiver(rx),
+                });
                 tx
             }
         };
 
-        // TODO: `send` does not wait for the other end of the channel to receive the message,
-        // so if nobody waits for it, we might leak an entry in `stage_targets` that will never
-        // be cleaned up. We can either:
-        // 1. schedule a cleanup task that iterates the entries cleaning up old ones
-        // 2. find some other API that allows us to .await until the other end receives the message,
-        //    and on a timeout, cleanup the entry anyway.
         tx.send(next_stage_context)
             .map_err(|_| exec_datafusion_err!("Could not send stage context info"))
     }
@@ -95,16 +151,24 @@
         actor_idx: usize,
     ) -> Result<StageContext, DataFusionError> {
         let rx = match self.stage_targets.entry((stage_id.clone(), actor_idx)) {
-            Entry::Occupied(entry) => match entry.get() {
-                Oneof::Sender(_) => return exec_err!("Programming error: while waiting for delegate info the entry in the StageDelegation target map cannot be a Sender"),
-                Oneof::Receiver(_) => match entry.remove() {
+            Entry::Occupied(entry) => match entry.get().value {
+                Oneof::Sender(_) => {
+                    return exec_err!(
+                        "Programming error: while waiting for delegate info the entry in the \
+                         StageDelegation target map cannot be a Sender"
+                    )
+                }
+                Oneof::Receiver(_) => match entry.remove().value {
                     Oneof::Sender(_) => unreachable!(),
-                    Oneof::Receiver(rx) => rx
+                    Oneof::Receiver(rx) => rx,
                 },
             },
             Entry::Vacant(entry) => {
                 let (tx, rx) = oneshot::channel();
-                entry.insert(Oneof::Sender(tx));
+                entry.insert(Value {
+                    expiry: Instant::now() + self.gc_ttl(),
+                    value: Oneof::Sender(tx),
+                });
                 rx
             }
         };
@@ -118,6 +182,17 @@
             )
         })
     }
+
+    // gc_ttl sets the expiry of entries in the map. Use 2 * the waiter's wait duration
+    // to avoid running GC too early.
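+    // (If the TTL were shorter than `wait_timeout`, the GC task could reap an entry
+    // while a waiter was still blocked on the corresponding channel.)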
+    fn gc_ttl(&self) -> Duration {
+        self.wait_timeout * 2
+    }
+}
+
+/// A map entry: one endpoint of the handshake channel, plus the instant after which
+/// the GC task may reap it.
+struct Value {
+    expiry: Instant,
+    value: Oneof,
 }
 
 enum Oneof {
@@ -222,6 +298,7 @@
 
         let received_context = wait_task1.await.unwrap().unwrap();
         assert_eq!(received_context.id, stage_context.id);
+        assert_eq!(delegation.stage_targets.len(), 0);
     }
 
     #[tokio::test]
@@ -287,4 +364,93 @@
             .unwrap();
         assert_eq!(received_context, stage_context);
     }
+
+    #[tokio::test]
+    async fn test_waiter_timeout_and_gc_cleanup() {
+        let stage_targets = Arc::new(DashMap::default());
+        let shutdown = Arc::new(Notify::new());
+        let delegation = StageDelegation {
+            stage_targets: stage_targets.clone(),
+            wait_timeout: Duration::from_millis(1),
+            notify: shutdown.clone(),
+        };
+        let stage_id = Uuid::new_v4().to_string();
+
+        // Actor waits but times out.
+        let result = delegation.wait_for_delegate_info(stage_id, 0).await;
+        assert!(result.is_err());
+        assert!(result.unwrap_err().to_string().contains("Timeout"));
+
+        // Wait for the entry's expiry time to pass.
+        tokio::time::sleep(delegation.gc_ttl()).await;
+
+        // Run GC to clean up the expired entry.
+        let gc_task = tokio::spawn(run_gc(
+            stage_targets.clone(),
+            shutdown.clone(),
+            Duration::from_millis(5),
+        ));
+
+        // Wait for GC to clear the map.
+        for _ in 0..10 {
+            tokio::time::sleep(Duration::from_millis(10)).await;
+            if stage_targets.is_empty() {
+                break;
+            }
+        }
+
+        // Stop GC by dropping the delegation, and assert that it shuts down.
+        drop(delegation);
+        gc_task.await.unwrap();
+
+        // After GC, the map should be empty.
+        assert_eq!(stage_targets.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_writer_only_and_gc_cleanup() {
+        let stage_targets = Arc::new(DashMap::default());
+        let shutdown = Arc::new(Notify::new());
+        let delegation = StageDelegation {
+            stage_targets: stage_targets.clone(),
+            wait_timeout: Duration::from_millis(1),
+            notify: shutdown.clone(),
+        };
+        let stage_id = Uuid::new_v4().to_string();
+        let stage_context = create_test_stage_context();
+
+        // Writer adds info without anyone waiting.
+        let result = delegation.add_delegate_info(stage_id, 0, stage_context);
+        assert!(result.is_ok());
+
+        // The entry should be parked in the map.
+        assert_eq!(stage_targets.len(), 1);
+
+        // Wait for the expiry time to pass (gc_ttl is 2 * wait_timeout).
+        tokio::time::sleep(delegation.gc_ttl()).await;
+
+        // Run GC to clean up the expired entry.
+        let gc_task = tokio::spawn(run_gc(
+            stage_targets.clone(),
+            shutdown.clone(),
+            Duration::from_millis(10),
+        ));
+
+        // Wait for GC to clear the map.
+        for _ in 0..10 {
+            tokio::time::sleep(Duration::from_millis(20)).await;
+            if stage_targets.is_empty() {
+                break;
+            }
+        }
+
+        // Stop GC by dropping the delegation, and assert that it shuts down.
+        drop(delegation);
+        gc_task.await.unwrap();
+
+        // After GC, the map should be empty.
+        assert_eq!(stage_targets.len(), 0);
+    }
 }
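--
A usage sketch for reviewers (not part of the patch): it shows the intended
handshake and where the GC matters when one side never shows up. The
`make_stage_context()` helper and the `example()` wrapper are hypothetical;
everything else mirrors the API in this diff.

    use std::sync::Arc;

    // Hypothetical helper: building a real StageContext is elided here
    // (see create_test_stage_context() in the tests).
    fn make_stage_context() -> StageContext {
        unimplemented!()
    }

    async fn example() -> Result<(), datafusion::error::DataFusionError> {
        // `StageDelegation::default()` spawns the GC task, so a Tokio runtime
        // must already be running.
        let delegation = Arc::new(StageDelegation::default());
        let stage_id = uuid::Uuid::new_v4().to_string();

        // A non-delegate worker waits (up to `wait_timeout`) for the delegate's info.
        let waiter = tokio::spawn({
            let delegation = Arc::clone(&delegation);
            let stage_id = stage_id.clone();
            async move { delegation.wait_for_delegate_info(stage_id, 1).await }
        });

        // The delegate publishes the next stage's context for actor 1. If the waiter
        // had already timed out (or never existed), this entry would sit in the map
        // until the GC task reaps it after `gc_ttl` (2 * wait_timeout).
        delegation.add_delegate_info(stage_id, 1, make_stage_context())?;

        // The waiter receives the published context.
        let _next_stage_ctx = waiter.await.unwrap()?;
        Ok(())
    }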