202 changes: 184 additions & 18 deletions src/stage_delegation/delegation.rs
@@ -2,8 +2,13 @@ use super::StageContext;
use dashmap::{DashMap, Entry};
use datafusion::common::{exec_datafusion_err, exec_err};
use datafusion::error::DataFusionError;
use std::time::Duration;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::oneshot;
use tokio::sync::Notify;
use tokio::time;
use tokio::time::Instant;

/// In each stage of the distributed plan, there will be N workers. All these workers
/// need to coordinate to pull data from the next stage, which will contain M workers.
@@ -30,20 +35,70 @@ use tokio::sync::oneshot;
/// On 2, the `wait_for_delegate_info` call will create an entry in the [DashMap] with a
/// [oneshot::Sender], and listen on the other end of the channel [oneshot::Receiver] for
/// the delegate to put something there.
///
/// It's possible for [StageContext] to "get lost" if `add_delegate_info` is called without
/// a corresponding call to `wait_for_delegate_info` or vice versa. In this case, a task will
/// reap any contexts that live for longer than the `gc_ttl`.
pub struct StageDelegation {
stage_targets: DashMap<(String, usize), Oneof>,
stage_targets: Arc<DashMap<(String, usize), Value>>,
wait_timeout: Duration,

/// notify is used to shut down the garbage collection task when the StageDelegation is dropped.
notify: Arc<Notify>,
Collaborator:
There should be no need for a Notify here; there's already tooling in the tokio and DataFusion ecosystem for cancelling tasks on Drop. For example, you could simply do

pub struct StageDelegation {
    stage_targets: Arc<DashMap<(String, usize), Value>>,
    wait_timeout: Duration,
    /// Dropping this task shuts down the garbage collection loop when the StageDelegation is dropped.
    _task: SpawnedTask<()>,
}

impl Default for StageDelegation {
    fn default() -> Self {
        let stage_targets = Arc::new(DashMap::default());
        Self {
            stage_targets: stage_targets.clone(),
            wait_timeout: Duration::from_secs(5),
            _task: SpawnedTask::spawn(run_gc(
                stage_targets.clone(),
                Duration::from_secs(30), /* gc period */
            )),
        }
    }
}

// run_gc will continuously clear expired entries from the map, checking every `period`.
// The loop ends when the owning SpawnedTask is dropped.
async fn run_gc(stage_targets: Arc<DashMap<(String, usize), Value>>, period: Duration) {
    loop {
        tokio::time::sleep(period).await;
        stage_targets.retain(|_key, value| value.expiry.gt(&Instant::now()));
    }
}

And it should work the same
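
(SpawnedTask lives in DataFusion's datafusion-common-runtime crate and cancels its inner tokio task on drop, so dropping StageDelegation stops the GC loop without an explicit shutdown signal.)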

}

impl Default for StageDelegation {
fn default() -> Self {
Self {
stage_targets: DashMap::default(),
let stage_targets = Arc::new(DashMap::default());
let notify = Arc::new(Notify::new());

let result = Self {
stage_targets: stage_targets.clone(),
wait_timeout: Duration::from_secs(5),
notify: notify.clone(),
};

// Run the GC task.
tokio::spawn(run_gc(
stage_targets.clone(),
notify.clone(),
Duration::from_secs(30), /* gc period */
));

result
}
}

const GC_PERIOD_SECONDS: usize = 30;
Collaborator:
Looks like this const is unused?


// run_gc will continuously clear expired entries from the map, checking every `period`. The
// function terminates if `shutdown` is signalled.
async fn run_gc(
stage_targets: Arc<DashMap<(String, usize), Value>>,
shutdown: Arc<Notify>,
period: Duration,
) {
loop {
tokio::select! {
_ = shutdown.notified() => {
break;
}
_ = tokio::time::sleep(period) => {
// Performance: This iterator is sharded, so it won't lock the whole map.
stage_targets.retain(|_key, value| {
value.expiry.gt(&Instant::now())
});
Comment on lines +87 to +90
@gabotechs (Collaborator), Aug 4, 2025:
.retain() iterates the whole map, so it will end up locking the whole map, unfortunately.

I think there should be ways of garbage collecting abandoned tasks without iterating the full map: maybe spawn a task with a big tokio::sleep() at the beginning for each added entry, which either gets cancelled due to a drop (the task was not abandoned) or runs to completion, including the tokio::sleep (the task was abandoned).
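
A minimal sketch of that per-entry reaper idea (the names spawn_reaper and Key are illustrative, not part of the PR): every inserted entry gets its own task that sleeps for the TTL and then removes the entry, so no full-map scan is needed.

use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use tokio::task::JoinHandle;

type Key = (String, usize);

// Illustrative sketch (not from the PR): a per-entry reaper instead of a
// periodic full-map GC. If the returned handle is never aborted, the sleep
// completes and the entry is treated as abandoned and removed.
fn spawn_reaper<V: Send + Sync + 'static>(
    targets: Arc<DashMap<Key, V>>,
    key: Key,
    ttl: Duration,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        tokio::time::sleep(ttl).await;
        // Only reached if the reaper was never cancelled, i.e. nobody
        // consumed the entry within the TTL.
        targets.remove(&key);
    })
}

The caller would store the returned handle next to the channel end and abort it (or drop a SpawnedTask wrapper) when the entry is consumed normally.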

}
}
}
}

impl Drop for StageDelegation {
fn drop(&mut self) {
self.notify.notify_one();
}
}

impl StageDelegation {
/// Puts the [StageContext] info so that an actor can pick it up with `wait_for_delegate_info`.
///
@@ -57,9 +112,13 @@ impl StageDelegation {
actor_idx: usize,
next_stage_context: StageContext,
) -> Result<(), DataFusionError> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let tx = match self.stage_targets.entry((stage_id, actor_idx)) {
Entry::Occupied(entry) => match entry.get() {
Oneof::Sender(_) => match entry.remove() {
Entry::Occupied(entry) => match entry.get().value {
Oneof::Sender(_) => match entry.remove().value {
Oneof::Sender(tx) => tx,
Oneof::Receiver(_) => unreachable!(),
},
Expand All @@ -69,17 +128,14 @@ impl StageDelegation {
},
Entry::Vacant(entry) => {
let (tx, rx) = oneshot::channel();
entry.insert(Oneof::Receiver(rx));
entry.insert(Value {
expiry: Instant::now().add(self.gc_ttl()),
value: Oneof::Receiver(rx),
});
tx
}
};

// TODO: `send` does not wait for the other end of the channel to receive the message,
// so if nobody waits for it, we might leak an entry in `stage_targets` that will never
// be cleaned up. We can either:
// 1. schedule a cleanup task that iterates the entries cleaning up old ones
// 2. find some other API that allows us to .await until the other end receives the message,
// and on a timeout, cleanup the entry anyway.
tx.send(next_stage_context)
.map_err(|_| exec_datafusion_err!("Could not send stage context info"))
}
@@ -95,16 +151,24 @@
actor_idx: usize,
) -> Result<StageContext, DataFusionError> {
let rx = match self.stage_targets.entry((stage_id.clone(), actor_idx)) {
Entry::Occupied(entry) => match entry.get() {
Oneof::Sender(_) => return exec_err!("Programming error: while waiting for delegate info the entry in the StageDelegation target map cannot be a Sender"),
Oneof::Receiver(_) => match entry.remove() {
Entry::Occupied(entry) => match entry.get().value {
Oneof::Sender(_) => {
return exec_err!(
"Programming error: while waiting for delegate info the entry in the \
StageDelegation target map cannot be a Sender"
)
}
Oneof::Receiver(_) => match entry.remove().value {
Oneof::Sender(_) => unreachable!(),
Oneof::Receiver(rx) => rx
Oneof::Receiver(rx) => rx,
},
},
Entry::Vacant(entry) => {
let (tx, rx) = oneshot::channel();
entry.insert(Oneof::Sender(tx));
entry.insert(Value {
expiry: Instant::now().add(self.gc_ttl()),
value: Oneof::Sender(tx),
});
rx
}
};
@@ -118,6 +182,17 @@
)
})
}

// gc_ttl is used to set the expiry of elements in the map. Use 2 * the waiter wait duration
// to avoid running gc too early.
fn gc_ttl(&self) -> Duration {
self.wait_timeout * 2
}
}

struct Value {
expiry: Instant,
value: Oneof,
}

enum Oneof {
@@ -129,6 +204,7 @@ mod tests {
mod tests {
use super::*;
use crate::stage_delegation::StageContext;
use futures::TryFutureExt;
use std::sync::Arc;
use uuid::Uuid;

@@ -222,6 +298,7 @@ mod tests {

let received_context = wait_task1.await.unwrap().unwrap();
assert_eq!(received_context.id, stage_context.id);
assert_eq!(0, delegation.stage_targets.len())
}

#[tokio::test]
@@ -287,4 +364,93 @@ mod tests {
.unwrap();
assert_eq!(received_context, stage_context);
}

#[tokio::test]
async fn test_waiter_timeout_and_gc_cleanup() {
let stage_targets = Arc::new(DashMap::default());
let shutdown = Arc::new(Notify::new());
let delegation = StageDelegation {
stage_targets: stage_targets.clone(),
wait_timeout: Duration::from_millis(1),
notify: shutdown.clone(),
};
let stage_id = Uuid::new_v4().to_string();

// Actor waits but times out
let result = delegation.wait_for_delegate_info(stage_id, 0).await;

assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Timeout"));

// Wait for expiry time to pass.
tokio::time::sleep(delegation.gc_ttl()).await;

// Run GC to clean up expired entries
let gc_task = tokio::spawn(run_gc(
stage_targets.clone(),
shutdown.clone(),
Duration::from_millis(5),
));

// Wait for GC to clear the map
for _ in 0..10 {
tokio::time::sleep(Duration::from_millis(10)).await;
if stage_targets.len() == 0 {
break;
}
}

// Stop GC by dropping the delegation. Assert that it has shutdown.
drop(delegation);
gc_task.await.unwrap();

// After GC, map should be cleared.
assert_eq!(stage_targets.len(), 0);
}

#[tokio::test]
async fn test_writer_only_and_gc_cleanup() {
let stage_targets = Arc::new(DashMap::default());
let shutdown = Arc::new(Notify::new());
let delegation = StageDelegation {
stage_targets: stage_targets.clone(),
wait_timeout: Duration::from_millis(1),
notify: shutdown.clone(),
};
let stage_id = Uuid::new_v4().to_string();
let stage_context = create_test_stage_context();

// Writer adds info without anyone waiting
let result = delegation.add_delegate_info(stage_id, 0, stage_context);

assert!(result.is_ok());

// Entry should be in map
assert_eq!(stage_targets.len(), 1);

// Wait for expiry time to pass (gc_ttl is 2 * wait_timeout)
tokio::time::sleep(delegation.gc_ttl()).await;

// Run GC to clean up expired entries
let gc_task = tokio::spawn(run_gc(
stage_targets.clone(),
shutdown.clone(),
Duration::from_millis(10),
));

// Wait for GC to clear the map
for _ in 0..10 {
tokio::time::sleep(Duration::from_millis(20)).await;
if stage_targets.len() == 0 {
break;
}
}

// Stop GC.
drop(delegation);
gc_task.await.unwrap();

// After GC, map should be cleared
assert_eq!(stage_targets.len(), 0);
}
}