Skip to content

Commit 8f79aa3

Browse files
committed
add workstealng tests
1 parent 59ea7f0 commit 8f79aa3

File tree

5 files changed

+78
-32
lines changed

5 files changed

+78
-32
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ name = "perf"
4040
harness = false
4141

4242
[features]
43-
default = ["affinity", "batching", "retry"]
43+
default = []
4444
affinity = ["dep:core_affinity"]
4545
batching = []
4646
retry = []

src/hive/inner/queue/channel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::sync::Arc;
1010
use std::time::Duration;
1111

1212
// time to wait when polling the global queue
13-
const RECV_TIMEOUT: Duration = Duration::from_secs(1);
13+
const RECV_TIMEOUT: Duration = Duration::from_millis(100);
1414

1515
/// Type alias for the input task channel sender
1616
type TaskSender<W> = crossbeam_channel::Sender<Task<W>>;

src/hive/inner/queue/workstealing.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ use parking_lot::RwLock;
1414
use rand::prelude::*;
1515
use std::ops::Deref;
1616
use std::sync::Arc;
17+
use std::thread;
18+
use std::time::Duration;
19+
20+
/// Time to wait after trying to pop and finding all queues empty.
21+
const EMPTY_DELAY: Duration = Duration::from_millis(100);
1722

1823
pub struct WorkstealingTaskQueues<W: Worker> {
1924
global: Arc<GlobalQueue<W>>,
@@ -40,7 +45,7 @@ impl<W: Worker> TaskQueues<W> for WorkstealingTaskQueues<W> {
4045

4146
fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) {
4247
let local_queues = self.local.read();
43-
assert!(local_queues.len() > end_index);
48+
assert!(local_queues.len() >= end_index);
4449
(start_index..end_index).for_each(|thread_index| local_queues[thread_index].update(config));
4550
}
4651

@@ -103,14 +108,22 @@ impl<W: Worker> GlobalQueue<W> {
103108
}
104109

105110
/// Tries to steal a task from a random worker using its `Stealer`.
106-
fn try_steal(&self) -> Option<Task<W>> {
111+
fn try_steal_from_worker(&self) -> Result<Task<W>, PopTaskError> {
107112
let stealers = self.stealers.read();
108113
let n = stealers.len();
109114
// randomize the stealing order, to prevent always stealing from the same thread
110115
std::iter::from_fn(|| Some(rand::rng().random_range(0..n)))
111116
.take(n)
112117
.filter_map(|i| stealers[i].steal().success())
113118
.next()
119+
.ok_or_else(|| {
120+
if self.is_closed() && self.queue.is_empty() {
121+
PopTaskError::Closed
122+
} else {
123+
thread::park_timeout(EMPTY_DELAY);
124+
PopTaskError::Empty
125+
}
126+
})
114127
}
115128

116129
/// Tries to steal a task from the global queue, otherwise tries to steal a task from another
@@ -119,7 +132,7 @@ impl<W: Worker> GlobalQueue<W> {
119132
if let Some(task) = self.queue.steal().success() {
120133
Ok(task)
121134
} else {
122-
self.try_steal().ok_or(PopTaskError::Empty)
135+
self.try_steal_from_worker()
123136
}
124137
}
125138

@@ -137,9 +150,10 @@ impl<W: Worker> GlobalQueue<W> {
137150
.steal_batch_with_limit_and_pop(local_batch, limit + 1)
138151
.success()
139152
{
140-
return Ok(task);
153+
Ok(task)
154+
} else {
155+
self.try_steal_from_worker()
141156
}
142-
self.try_steal().ok_or(PopTaskError::Empty)
143157
}
144158

145159
fn is_closed(&self) -> bool {

src/hive/inner/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,8 @@ mod retry {
717717

718718
#[cfg(test)]
719719
mod tests {
720-
use crate::bee::stock::ThunkWorker;
721720
use crate::bee::DefaultQueen;
721+
use crate::bee::stock::ThunkWorker;
722722
use crate::hive::ChannelTaskQueues;
723723

724724
type VoidThunkWorker = ThunkWorker<()>;

src/hive/mod.rs

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2025,6 +2025,7 @@ mod batching_tests {
20252025
hive: &Hive<DefaultQueen<ThunkWorker<ThreadId>>, T>,
20262026
num_threads: usize,
20272027
batch_limit: usize,
2028+
assert_exact: bool,
20282029
) {
20292030
let tasks_per_thread = batch_limit + 2;
20302031
let (tx, rx) = crate::hive::outcome_channel();
@@ -2038,52 +2039,83 @@ mod batching_tests {
20382039
hive.join();
20392040
let thread_counts = count_thread_ids(rx, task_ids);
20402041
assert_eq!(thread_counts.len(), num_threads);
2041-
assert!(
2042-
thread_counts
2043-
.values()
2044-
.all(|&count| count == tasks_per_thread)
2042+
assert_eq!(
2043+
thread_counts.values().sum::<usize>(),
2044+
tasks_per_thread * num_threads
20452045
);
2046+
if assert_exact {
2047+
assert!(
2048+
thread_counts
2049+
.values()
2050+
.all(|&count| count == tasks_per_thread)
2051+
);
2052+
} else {
2053+
assert!(thread_counts.values().all(|&count| count > 0));
2054+
}
20462055
}
20472056

20482057
#[rstest]
2049-
fn test_batching<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
2050-
where
2051-
B: TaskQueuesBuilder,
2052-
F: Fn(bool) -> B,
2053-
{
2058+
fn test_batching_channel() {
20542059
const NUM_THREADS: usize = 4;
20552060
const BATCH_LIMIT: usize = 24;
2056-
let hive = builder_factory(false)
2061+
let hive = channel_builder(false)
20572062
.with_worker_default()
20582063
.num_threads(NUM_THREADS)
20592064
.batch_limit(BATCH_LIMIT)
20602065
.build();
2061-
run_test(&hive, NUM_THREADS, BATCH_LIMIT);
2066+
run_test(&hive, NUM_THREADS, BATCH_LIMIT, true);
20622067
}
20632068

20642069
#[rstest]
2065-
fn test_set_batch_limit<B, F>(
2066-
#[values(channel_builder, workstealing_builder)] builder_factory: F,
2067-
) where
2068-
B: TaskQueuesBuilder,
2069-
F: Fn(bool) -> B,
2070-
{
2070+
fn test_batching_workstealing() {
2071+
const NUM_THREADS: usize = 4;
2072+
const BATCH_LIMIT: usize = 24;
2073+
let hive = workstealing_builder(false)
2074+
.with_worker_default()
2075+
.num_threads(NUM_THREADS)
2076+
.batch_limit(BATCH_LIMIT)
2077+
.build();
2078+
run_test(&hive, NUM_THREADS, BATCH_LIMIT, false);
2079+
}
2080+
2081+
#[rstest]
2082+
fn test_set_batch_limit_channel() {
20712083
const NUM_THREADS: usize = 4;
20722084
const BATCH_LIMIT_0: usize = 10;
2073-
const BATCH_LIMIT_1: usize = 20;
2074-
const BATCH_LIMIT_2: usize = 50;
2075-
let hive = builder_factory(false)
2085+
const BATCH_LIMIT_1: usize = 50;
2086+
const BATCH_LIMIT_2: usize = 20;
2087+
let hive = channel_builder(false)
20762088
.with_worker_default()
20772089
.num_threads(NUM_THREADS)
20782090
.batch_limit(BATCH_LIMIT_0)
20792091
.build();
2080-
run_test(&hive, NUM_THREADS, BATCH_LIMIT_0);
2092+
run_test(&hive, NUM_THREADS, BATCH_LIMIT_0, true);
20812093
// increase batch size
2082-
hive.set_worker_batch_limit(BATCH_LIMIT_2);
2083-
run_test(&hive, NUM_THREADS, BATCH_LIMIT_2);
2094+
hive.set_worker_batch_limit(BATCH_LIMIT_1);
2095+
run_test(&hive, NUM_THREADS, BATCH_LIMIT_1, true);
20842096
// decrease batch size
2097+
hive.set_worker_batch_limit(BATCH_LIMIT_2);
2098+
run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, true);
2099+
}
2100+
2101+
#[rstest]
2102+
fn test_set_batch_limit_workstealing() {
2103+
const NUM_THREADS: usize = 4;
2104+
const BATCH_LIMIT_0: usize = 10;
2105+
const BATCH_LIMIT_1: usize = 50;
2106+
const BATCH_LIMIT_2: usize = 20;
2107+
let hive = workstealing_builder(false)
2108+
.with_worker_default()
2109+
.num_threads(NUM_THREADS)
2110+
.batch_limit(BATCH_LIMIT_0)
2111+
.build();
2112+
run_test(&hive, NUM_THREADS, BATCH_LIMIT_0, false);
2113+
// increase batch size
20852114
hive.set_worker_batch_limit(BATCH_LIMIT_1);
2086-
run_test(&hive, NUM_THREADS, BATCH_LIMIT_1);
2115+
run_test(&hive, NUM_THREADS, BATCH_LIMIT_1, false);
2116+
// decrease batch size
2117+
hive.set_worker_batch_limit(BATCH_LIMIT_2);
2118+
run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, false);
20872119
}
20882120

20892121
#[rstest]

0 commit comments

Comments
 (0)