Skip to content

Commit 46666b6

Browse files
delegation: add gc task to clean up abandoned contexts
This change addresses a TODO in the `StageDelegation` struct. Previously, `StageContext`s could leak if a context was produced but never consumed (or if a waiter created a map entry and timed out without any corresponding producer call). This change introduces a GC task that clears contexts from the map once they have been there for longer than a certain TTL; the proposed TTL is 2 * `wait_timeout`. When the `StageDelegation` is dropped, the GC task is terminated. Testing: adds two unit tests (one for `add_delegate_info` and one for `wait_for_delegate_info`) verifying that abandoned map entries are garbage collected appropriately.
1 parent bd9e61f commit 46666b6

File tree

1 file changed

+184
-18
lines changed

1 file changed

+184
-18
lines changed

src/stage_delegation/delegation.rs

Lines changed: 184 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ use super::StageContext;
22
use dashmap::{DashMap, Entry};
33
use datafusion::common::{exec_datafusion_err, exec_err};
44
use datafusion::error::DataFusionError;
5-
use std::time::Duration;
5+
use std::ops::Add;
6+
use std::sync::Arc;
7+
use std::time::{Duration, SystemTime, UNIX_EPOCH};
68
use tokio::sync::oneshot;
9+
use tokio::sync::Notify;
10+
use tokio::time;
11+
use tokio::time::Instant;
712

813
/// In each stage of the distributed plan, there will be N workers. All these workers
914
/// need to coordinate to pull data from the next stage, which will contain M workers.
@@ -30,20 +35,70 @@ use tokio::sync::oneshot;
3035
/// On 2, the `wait_for_delegate_info` call will create an entry in the [DashMap] with a
3136
/// [oneshot::Sender], and listen on the other end of the channel [oneshot::Receiver] for
3237
/// the delegate to put something there.
38+
///
39+
/// It's possible for [StageContext] to "get lost" if `add_delegate_info` is called without
40+
/// a corresponding call to `wait_for_delegate_info` or vice versa. In this case, a task will
41+
/// reap any contexts that live for longer than the `gc_ttl`.
3342
pub struct StageDelegation {
    /// Rendezvous map keyed by `(stage_id, actor_idx)`. Holds whichever channel
    /// end arrived first (producer's receiver or waiter's sender), tagged with
    /// an expiry so the GC task can reap abandoned entries.
    stage_targets: Arc<DashMap<(String, usize), Value>>,
    /// Timeout used by `wait_for_delegate_info`; the GC TTL is derived from it
    /// (2x) in `gc_ttl`.
    wait_timeout: Duration,

    /// notify is used to shut down the garbage collection task when the StageDelegation is dropped.
    notify: Arc<Notify>,
}
3749

3850
impl Default for StageDelegation {
3951
fn default() -> Self {
40-
Self {
41-
stage_targets: DashMap::default(),
52+
let stage_targets = Arc::new(DashMap::default());
53+
let notify = Arc::new(Notify::new());
54+
55+
let result = Self {
56+
stage_targets: stage_targets.clone(),
4257
wait_timeout: Duration::from_secs(5),
58+
notify: notify.clone(),
59+
};
60+
61+
// Run the GC task.
62+
tokio::spawn(run_gc(
63+
stage_targets.clone(),
64+
notify.clone(),
65+
Duration::from_secs(30), /* gc period */
66+
));
67+
68+
result
69+
}
70+
}
71+
72+
const GC_PERIOD_SECONDS: usize = 30;
73+
74+
// run_gc will continuously clear expired entries from the map, checking every `period`. The
75+
// function terminates if `shutdown` is signalled.
76+
async fn run_gc(
77+
stage_targets: Arc<DashMap<(String, usize), Value>>,
78+
shutdown: Arc<Notify>,
79+
period: Duration,
80+
) {
81+
loop {
82+
tokio::select! {
83+
_ = shutdown.notified() => {
84+
break;
85+
}
86+
_ = tokio::time::sleep(period) => {
87+
// Performance: This iterator is sharded, so it won't lock the whole map.
88+
stage_targets.retain(|_key, value| {
89+
value.expiry.gt(&Instant::now())
90+
});
91+
}
4392
}
4493
}
4594
}
4695

96+
/// Shuts down the background GC task (spawned in `Default::default`) when the
/// `StageDelegation` is dropped.
impl Drop for StageDelegation {
    fn drop(&mut self) {
        // Per tokio's `Notify` docs, `notify_one` stores a permit if no task is
        // currently awaiting `notified()`, so the shutdown signal is not lost
        // even if it races the GC task's sleep.
        self.notify.notify_one();
    }
}
101+
47102
impl StageDelegation {
48103
/// Puts the [StageContext] info so that an actor can pick it up with `wait_for_delegate_info`.
49104
///
@@ -57,9 +112,13 @@ impl StageDelegation {
57112
actor_idx: usize,
58113
next_stage_context: StageContext,
59114
) -> Result<(), DataFusionError> {
115+
let now = SystemTime::now()
116+
.duration_since(UNIX_EPOCH)
117+
.unwrap()
118+
.as_secs();
60119
let tx = match self.stage_targets.entry((stage_id, actor_idx)) {
61-
Entry::Occupied(entry) => match entry.get() {
62-
Oneof::Sender(_) => match entry.remove() {
120+
Entry::Occupied(entry) => match entry.get().value {
121+
Oneof::Sender(_) => match entry.remove().value {
63122
Oneof::Sender(tx) => tx,
64123
Oneof::Receiver(_) => unreachable!(),
65124
},
@@ -69,17 +128,14 @@ impl StageDelegation {
69128
},
70129
Entry::Vacant(entry) => {
71130
let (tx, rx) = oneshot::channel();
72-
entry.insert(Oneof::Receiver(rx));
131+
entry.insert(Value {
132+
expiry: Instant::now().add(self.gc_ttl()),
133+
value: Oneof::Receiver(rx),
134+
});
73135
tx
74136
}
75137
};
76138

77-
// TODO: `send` does not wait for the other end of the channel to receive the message,
78-
// so if nobody waits for it, we might leak an entry in `stage_targets` that will never
79-
// be cleaned up. We can either:
80-
// 1. schedule a cleanup task that iterates the entries cleaning up old ones
81-
// 2. find some other API that allows us to .await until the other end receives the message,
82-
// and on a timeout, cleanup the entry anyway.
83139
tx.send(next_stage_context)
84140
.map_err(|_| exec_datafusion_err!("Could not send stage context info"))
85141
}
@@ -95,16 +151,24 @@ impl StageDelegation {
95151
actor_idx: usize,
96152
) -> Result<StageContext, DataFusionError> {
97153
let rx = match self.stage_targets.entry((stage_id.clone(), actor_idx)) {
98-
Entry::Occupied(entry) => match entry.get() {
99-
Oneof::Sender(_) => return exec_err!("Programming error: while waiting for delegate info the entry in the StageDelegation target map cannot be a Sender"),
100-
Oneof::Receiver(_) => match entry.remove() {
154+
Entry::Occupied(entry) => match entry.get().value {
155+
Oneof::Sender(_) => {
156+
return exec_err!(
157+
"Programming error: while waiting for delegate info the entry in the \
158+
StageDelegation target map cannot be a Sender"
159+
)
160+
}
161+
Oneof::Receiver(_) => match entry.remove().value {
101162
Oneof::Sender(_) => unreachable!(),
102-
Oneof::Receiver(rx) => rx
163+
Oneof::Receiver(rx) => rx,
103164
},
104165
},
105166
Entry::Vacant(entry) => {
106167
let (tx, rx) = oneshot::channel();
107-
entry.insert(Oneof::Sender(tx));
168+
entry.insert(Value {
169+
expiry: Instant::now().add(self.gc_ttl()),
170+
value: Oneof::Sender(tx),
171+
});
108172
rx
109173
}
110174
};
@@ -118,6 +182,17 @@ impl StageDelegation {
118182
)
119183
})
120184
}
185+
186+
// gc_ttl is used to set the expiry of elements in the map. Use 2 * the waiter wait duration
187+
// to avoid running gc too early.
188+
fn gc_ttl(&self) -> Duration {
189+
self.wait_timeout * 2
190+
}
191+
}
192+
193+
/// Map entry stored in `stage_targets`: a channel endpoint plus the deadline
/// after which `run_gc` may reap it.
struct Value {
    // Set to `Instant::now() + gc_ttl()` when the entry is inserted.
    expiry: Instant,
    // Whichever channel end was stored by the side that arrived first.
    value: Oneof,
}
122197

123198
enum Oneof {
@@ -129,6 +204,7 @@ enum Oneof {
129204
mod tests {
130205
use super::*;
131206
use crate::stage_delegation::StageContext;
207+
use futures::TryFutureExt;
132208
use std::sync::Arc;
133209
use uuid::Uuid;
134210

@@ -222,6 +298,7 @@ mod tests {
222298

223299
let received_context = wait_task1.await.unwrap().unwrap();
224300
assert_eq!(received_context.id, stage_context.id);
301+
assert_eq!(0, delegation.stage_targets.len())
225302
}
226303

227304
#[tokio::test]
@@ -287,4 +364,93 @@ mod tests {
287364
.unwrap();
288365
assert_eq!(received_context, stage_context);
289366
}
367+
368+
#[tokio::test]
369+
async fn test_waiter_timeout_and_gc_cleanup() {
370+
let stage_targets = Arc::new(DashMap::default());
371+
let shutdown = Arc::new(Notify::new());
372+
let delegation = StageDelegation {
373+
stage_targets: stage_targets.clone(),
374+
wait_timeout: Duration::from_millis(1),
375+
notify: shutdown.clone(),
376+
};
377+
let stage_id = Uuid::new_v4().to_string();
378+
379+
// Actor waits but times out
380+
let result = delegation.wait_for_delegate_info(stage_id, 0).await;
381+
382+
assert!(result.is_err());
383+
assert!(result.unwrap_err().to_string().contains("Timeout"));
384+
385+
// Wait for expiry time to pass.
386+
tokio::time::sleep(delegation.gc_ttl()).await;
387+
388+
// Run GC to clean up expired entries
389+
let gc_task = tokio::spawn(run_gc(
390+
stage_targets.clone(),
391+
shutdown.clone(),
392+
Duration::from_millis(5),
393+
));
394+
395+
// Wait for GC to clear the map
396+
for _ in 0..10 {
397+
tokio::time::sleep(Duration::from_millis(10)).await;
398+
if stage_targets.len() == 0 {
399+
break;
400+
}
401+
}
402+
403+
// Stop GC by dropping the delegation. Assert that it has shutdown.
404+
drop(delegation);
405+
gc_task.await.unwrap();
406+
407+
// After GC, map should be cleared.
408+
assert_eq!(stage_targets.len(), 0);
409+
}
410+
411+
#[tokio::test]
412+
async fn test_writer_only_and_gc_cleanup() {
413+
let stage_targets = Arc::new(DashMap::default());
414+
let shutdown = Arc::new(Notify::new());
415+
let delegation = StageDelegation {
416+
stage_targets: stage_targets.clone(),
417+
wait_timeout: Duration::from_millis(1),
418+
notify: shutdown.clone(),
419+
};
420+
let stage_id = Uuid::new_v4().to_string();
421+
let stage_context = create_test_stage_context();
422+
423+
// Writer adds info without anyone waiting
424+
let result = delegation.add_delegate_info(stage_id, 0, stage_context);
425+
426+
assert!(result.is_ok());
427+
428+
// Entry should be in map
429+
assert_eq!(stage_targets.len(), 1);
430+
431+
// Wait for expiry time to pass (gc_ttl is 2 * wait_timeout)
432+
tokio::time::sleep(delegation.gc_ttl()).await;
433+
434+
// Run GC to cleanup expired entries
435+
let gc_task = tokio::spawn(run_gc(
436+
stage_targets.clone(),
437+
shutdown.clone(),
438+
Duration::from_millis(10),
439+
));
440+
441+
// Wait for GC to clear the map
442+
for _ in 0..10 {
443+
tokio::time::sleep(Duration::from_millis(20)).await;
444+
if stage_targets.len() == 0 {
445+
break;
446+
}
447+
}
448+
449+
// Stop GC.
450+
drop(delegation);
451+
gc_task.await.unwrap();
452+
453+
// After GC, map should be cleared
454+
assert_eq!(stage_targets.len(), 0);
455+
}
290456
}

0 commit comments

Comments
 (0)