Skip to content

Commit 0ef85b5

Browse files
committed
better batch handling
1 parent 6f66d52 commit 0ef85b5

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

compute/src/payloads/stats.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl TaskStats {
5151
}
5252

5353
/// Records the execution time of the task.
54-
#[deprecated = "will be removed later"]
54+
/// TODO: #[deprecated = "will be removed later"]
5555
pub fn record_execution_time(mut self, started_at: Instant) -> Self {
5656
self.execution_time = Instant::now().duration_since(started_at).as_nanos();
5757
self

compute/src/workers/workflow.rs

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,39 @@ impl WorkflowsWorker {
9393
///
9494
/// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic.
9595
pub async fn run_batch(&mut self, batch_size: usize) {
96-
// TODO: need some better batch_size error handling here
96+
assert!(
97+
batch_size <= Self::MAX_BATCH_SIZE,
98+
"Batch size must not be larger than {}",
99+
Self::MAX_BATCH_SIZE
100+
);
101+
97102
loop {
98-
// get tasks in batch from the channel
99-
let mut task_buffer = Vec::new();
100-
let num_tasks = self
101-
.workflow_rx
102-
.recv_many(&mut task_buffer, batch_size)
103-
.await;
104-
105-
if num_tasks == 0 {
106-
return self.shutdown();
103+
let mut tasks = Vec::new();
104+
105+
// get tasks in batch from the channel, we enter the loop if:
106+
// (1) there are no tasks, or,
107+
// (2) there are tasks less than the batch size and the channel is not empty
108+
while tasks.len() == 0 || (tasks.len() < batch_size && !self.workflow_rx.is_empty()) {
109+
let limit = batch_size - tasks.len();
110+
match self.workflow_rx.recv_many(&mut tasks, limit).await {
111+
// 0 tasks returned means that the channel is closed
112+
0 => return self.shutdown(),
113+
_ => {
114+
// wait a small amount of time to allow for more tasks to be sent into the channel
115+
tokio::time::sleep(std::time::Duration::from_millis(256)).await;
116+
}
117+
}
107118
}
108119

109120
// process the batch
121+
let num_tasks = tasks.len();
122+
debug_assert!(
123+
num_tasks <= batch_size,
124+
"number of tasks cant be larger than batch size"
125+
);
126+
debug_assert!(num_tasks != 0, "number of tasks cant be zero");
110127
log::info!("Processing {} workflows in batch", num_tasks);
111-
let mut batch = task_buffer.into_iter().map(|b| (b, &self.publish_tx));
128+
let mut batch = tasks.into_iter().map(|b| (b, &self.publish_tx));
112129
match num_tasks {
113130
1 => {
114131
WorkflowsWorker::execute(batch.next().unwrap()).await;

0 commit comments

Comments
 (0)