Skip to content

Commit 8620f61

Browse files
asdf
1 parent bd9e61f commit 8620f61

File tree

1 file changed

+177
-12
lines changed

1 file changed

+177
-12
lines changed

src/stage_delegation/delegation.rs

Lines changed: 177 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ use super::StageContext;
22
use dashmap::{DashMap, Entry};
33
use datafusion::common::{exec_datafusion_err, exec_err};
44
use datafusion::error::DataFusionError;
5-
use std::time::Duration;
5+
use std::ops::Add;
6+
use std::sync::Arc;
7+
use std::time::{Duration, SystemTime, UNIX_EPOCH};
68
use tokio::sync::oneshot;
9+
use tokio::sync::Notify;
10+
use tokio::time;
11+
use tokio::time::Instant;
712

813
/// In each stage of the distributed plan, there will be N workers. All these workers
914
/// need to coordinate to pull data from the next stage, which will contain M workers.
@@ -31,20 +36,67 @@ use tokio::sync::oneshot;
3136
/// [oneshot::Sender], and listen on the other end of the channel [oneshot::Receiver] for
3237
/// the delegate to put something there.
3338
pub struct StageDelegation {
34-
stage_targets: DashMap<(String, usize), Oneof>,
39+
stage_targets: Arc<DashMap<(String, usize), Value>>,
3540
wait_timeout: Duration,
41+
42+
notify: Arc<Notify>,
3643
}
3744

3845
impl Default for StageDelegation {
3946
fn default() -> Self {
40-
Self {
41-
stage_targets: DashMap::default(),
47+
let stage_targets = Arc::new(DashMap::default());
48+
let notify = Arc::new(Notify::new());
49+
50+
let result = Self {
51+
stage_targets: stage_targets.clone(),
4252
wait_timeout: Duration::from_secs(5),
53+
54+
notify: notify.clone(),
55+
};
56+
57+
tokio::spawn(run_gc_async(
58+
stage_targets.clone(),
59+
notify.clone(),
60+
Duration::from_secs(30), /* gc period */
61+
));
62+
63+
result
64+
}
65+
}
66+
67+
// `period` is the interval at which gc runs to purge expired stage_targets entries which will
68+
// never be read. This may happen if the actor encounters an error before it can read
69+
// the delegate info.
70+
async fn run_gc_async(
71+
stage_targets: Arc<DashMap<(String, usize), Value>>,
72+
shutdown: Arc<Notify>,
73+
period: Duration,
74+
) {
75+
loop {
76+
tokio::select! {
77+
_ = shutdown.notified() => {
78+
break;
79+
}
80+
_ = tokio::time::sleep(period) => {
81+
// PERF: This iterator is sharded, so it won't lock the whole map.
82+
stage_targets.retain(|_key, value| {
83+
value.expiry.gt(&Instant::now())
84+
});
85+
}
4386
}
4487
}
4588
}
4689

90+
impl Drop for StageDelegation {
91+
fn drop(&mut self) {
92+
self.notify.notify_one();
93+
}
94+
}
95+
4796
impl StageDelegation {
97+
fn gc_ttl(&self) -> Duration {
98+
self.wait_timeout * 2
99+
}
48100
/// Puts the [StageContext] info so that an actor can pick it up with `wait_for_delegate_info`.
49101
///
50102
/// - If the actor was already waiting for this info, it just puts it on the
@@ -57,9 +109,13 @@ impl StageDelegation {
57109
actor_idx: usize,
58110
next_stage_context: StageContext,
59111
) -> Result<(), DataFusionError> {
112+
let now = SystemTime::now()
113+
.duration_since(UNIX_EPOCH)
114+
.unwrap()
115+
.as_secs();
60116
let tx = match self.stage_targets.entry((stage_id, actor_idx)) {
61-
Entry::Occupied(entry) => match entry.get() {
62-
Oneof::Sender(_) => match entry.remove() {
117+
Entry::Occupied(entry) => match entry.get().value {
118+
Oneof::Sender(_) => match entry.remove().value {
63119
Oneof::Sender(tx) => tx,
64120
Oneof::Receiver(_) => unreachable!(),
65121
},
@@ -69,7 +125,11 @@ impl StageDelegation {
69125
},
70126
Entry::Vacant(entry) => {
71127
let (tx, rx) = oneshot::channel();
72-
entry.insert(Oneof::Receiver(rx));
128+
entry.insert(Value {
129+
// Use 2 * the waiter wait duration for now.
130+
expiry: Instant::now().add(self.gc_ttl()),
131+
value: Oneof::Receiver(rx),
132+
});
73133
tx
74134
}
75135
};
@@ -95,16 +155,25 @@ impl StageDelegation {
95155
actor_idx: usize,
96156
) -> Result<StageContext, DataFusionError> {
97157
let rx = match self.stage_targets.entry((stage_id.clone(), actor_idx)) {
98-
Entry::Occupied(entry) => match entry.get() {
99-
Oneof::Sender(_) => return exec_err!("Programming error: while waiting for delegate info the entry in the StageDelegation target map cannot be a Sender"),
100-
Oneof::Receiver(_) => match entry.remove() {
158+
Entry::Occupied(entry) => match entry.get().value {
159+
Oneof::Sender(_) => {
160+
return exec_err!(
161+
"Programming error: while waiting for delegate info the entry in the \
162+
StageDelegation target map cannot be a Sender"
163+
)
164+
}
165+
Oneof::Receiver(_) => match entry.remove().value {
101166
Oneof::Sender(_) => unreachable!(),
102-
Oneof::Receiver(rx) => rx
167+
Oneof::Receiver(rx) => rx,
103168
},
104169
},
105170
Entry::Vacant(entry) => {
106171
let (tx, rx) = oneshot::channel();
107-
entry.insert(Oneof::Sender(tx));
172+
entry.insert(Value {
173+
// Use 2 * the waiter wait duration for now.
174+
expiry: Instant::now().add(self.gc_ttl()),
175+
value: Oneof::Sender(tx),
176+
});
108177
rx
109178
}
110179
};
@@ -120,6 +189,11 @@ impl StageDelegation {
120189
}
121190
}
122191

192+
struct Value {
193+
expiry: Instant,
194+
value: Oneof,
195+
}
196+
123197
enum Oneof {
124198
Sender(oneshot::Sender<StageContext>),
125199
Receiver(oneshot::Receiver<StageContext>),
@@ -129,6 +203,7 @@ enum Oneof {
129203
mod tests {
130204
use super::*;
131205
use crate::stage_delegation::StageContext;
206+
use futures::TryFutureExt;
132207
use std::sync::Arc;
133208
use uuid::Uuid;
134209

@@ -222,6 +297,7 @@ mod tests {
222297

223298
let received_context = wait_task1.await.unwrap().unwrap();
224299
assert_eq!(received_context.id, stage_context.id);
300+
assert_eq!(0, delegation.stage_targets.len())
225301
}
226302

227303
#[tokio::test]
@@ -287,4 +363,93 @@ mod tests {
287363
.unwrap();
288364
assert_eq!(received_context, stage_context);
289365
}
366+
367+
#[tokio::test]
368+
async fn test_waiter_timeout_and_gc_cleanup() {
369+
let stage_targets = Arc::new(DashMap::default());
370+
let shutdown = Arc::new(Notify::new());
371+
let delegation = StageDelegation {
372+
stage_targets: stage_targets.clone(),
373+
wait_timeout: Duration::from_millis(1),
374+
notify: shutdown.clone(),
375+
};
376+
let stage_id = Uuid::new_v4().to_string();
377+
378+
// Actor waits but times out
379+
let result = delegation.wait_for_delegate_info(stage_id, 0).await;
380+
381+
assert!(result.is_err());
382+
assert!(result.unwrap_err().to_string().contains("Timeout"));
383+
384+
// Wait for expiry time to pass.
385+
tokio::time::sleep(delegation.gc_ttl()).await;
386+
387+
// Run GC to cleanup expired entries
388+
let gc_task = tokio::spawn(run_gc_async(
389+
stage_targets.clone(),
390+
shutdown.clone(),
391+
Duration::from_millis(5),
392+
));
393+
394+
// Wait for GC to clear the map
395+
for _ in 0..50 {
396+
tokio::time::sleep(Duration::from_millis(10)).await;
397+
if stage_targets.len() == 0 {
398+
break;
399+
}
400+
}
401+
402+
// Stop GC by dropping the delegation. Assert that it has shut down.
403+
drop(delegation);
404+
gc_task.await.unwrap();
405+
406+
// After GC, map should be cleared
407+
assert_eq!(stage_targets.len(), 0);
408+
}
409+
410+
#[tokio::test]
411+
async fn test_writer_only_and_gc_cleanup() {
412+
let stage_targets = Arc::new(DashMap::default());
413+
let shutdown = Arc::new(Notify::new());
414+
let delegation = StageDelegation {
415+
stage_targets: stage_targets.clone(),
416+
wait_timeout: Duration::from_millis(1),
417+
notify: shutdown.clone(),
418+
};
419+
let stage_id = Uuid::new_v4().to_string();
420+
let stage_context = create_test_stage_context();
421+
422+
// Writer adds info without anyone waiting
423+
let result = delegation.add_delegate_info(stage_id, 0, stage_context);
424+
425+
assert!(result.is_ok());
426+
427+
// Entry should be in map
428+
assert_eq!(stage_targets.len(), 1);
429+
430+
// Wait for expiry time to pass (gc_ttl is 2 * wait_timeout)
431+
tokio::time::sleep(delegation.gc_ttl()).await;
432+
433+
// Run GC to cleanup expired entries
434+
let gc_task = tokio::spawn(run_gc_async(
435+
stage_targets.clone(),
436+
shutdown.clone(),
437+
Duration::from_millis(10),
438+
));
439+
440+
// Wait for GC to clear the map
441+
for _ in 0..50 {
442+
tokio::time::sleep(Duration::from_millis(20)).await;
443+
if stage_targets.len() == 0 {
444+
break;
445+
}
446+
}
447+
448+
// Stop GC
449+
drop(delegation);
450+
gc_task.await.unwrap();
451+
452+
// After GC, map should be cleared
453+
assert_eq!(stage_targets.len(), 0);
454+
}
290455
}

0 commit comments

Comments
 (0)