Skip to content

Commit c8acb3a

Browse files
committed
feat: batching util supports optional batch size
1 parent c1eaee9 commit c8acb3a

File tree

3 files changed

+284
-16
lines changed

3 files changed

+284
-16
lines changed

src/execution/source_indexer.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,10 @@ impl SourceIndexingContext {
304304
rows_to_retry,
305305
}),
306306
setup_execution_ctx,
307-
update_once_batcher: batching::Batcher::new(UpdateOnceRunner),
307+
update_once_batcher: batching::Batcher::new(
308+
UpdateOnceRunner,
309+
batching::BatcherOptions::default(),
310+
),
308311
}))
309312
}
310313

src/ops/factory_bases.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ impl<E: BatchedFunctionExecutor> BatchedFunctionExecutorWrapper<E> {
407407
Self {
408408
enable_cache: executor.enable_cache(),
409409
behavior_version: executor.behavior_version(),
410-
batcher: batching::Batcher::new(executor),
410+
batcher: batching::Batcher::new(executor, batching::BatcherOptions::default()),
411411
}
412412
}
413413
}

src/utils/batching.rs

Lines changed: 279 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ impl<I, O> Default for Batch<I, O> {
3636
enum BatcherState<I, O> {
3737
#[default]
3838
Idle,
39-
Busy(Option<Batch<I, O>>),
39+
Busy {
40+
pending_batch: Option<Batch<I, O>>,
41+
ongoing_count: usize,
42+
},
4043
}
4144

4245
struct BatcherData<R: Runner + 'static> {
@@ -95,6 +98,7 @@ impl<R: Runner + 'static> BatcherData<R> {
9598

9699
pub struct Batcher<R: Runner + 'static> {
97100
data: Arc<BatcherData<R>>,
101+
options: BatcherOptions,
98102
}
99103

100104
enum BatchExecutionAction<R: Runner + 'static> {
@@ -106,33 +110,62 @@ enum BatchExecutionAction<R: Runner + 'static> {
106110
num_cancelled_tx: watch::Sender<usize>,
107111
},
108112
}
113+
114+
#[derive(Default)]
115+
pub struct BatcherOptions {
116+
pub max_batch_size: Option<usize>,
117+
}
109118
impl<R: Runner + 'static> Batcher<R> {
110-
pub fn new(runner: R) -> Self {
119+
pub fn new(runner: R, options: BatcherOptions) -> Self {
111120
Self {
112121
data: Arc::new(BatcherData {
113122
runner,
114123
state: Mutex::new(BatcherState::Idle),
115124
}),
125+
options,
116126
}
117127
}
118128
pub async fn run(&self, input: R::Input) -> Result<R::Output> {
119129
let batch_exec_action: BatchExecutionAction<R> = {
120130
let mut state = self.data.state.lock().unwrap();
121131
match &mut *state {
122132
state @ BatcherState::Idle => {
123-
*state = BatcherState::Busy(None);
133+
*state = BatcherState::Busy {
134+
pending_batch: None,
135+
ongoing_count: 1,
136+
};
124137
BatchExecutionAction::Inline { input }
125138
}
126-
BatcherState::Busy(batch) => {
127-
let batch = batch.get_or_insert_default();
139+
BatcherState::Busy {
140+
pending_batch,
141+
ongoing_count,
142+
} => {
143+
let batch = pending_batch.get_or_insert_default();
128144
batch.inputs.push(input);
129145

130146
let (output_tx, output_rx) = oneshot::channel();
131147
batch.output_txs.push(output_tx);
132148

149+
let num_cancelled_tx = batch.num_cancelled_tx.clone();
150+
151+
// Check if we've reached max_batch_size and need to flush immediately
152+
let should_flush = self
153+
.options
154+
.max_batch_size
155+
.map(|max_size| batch.inputs.len() >= max_size)
156+
.unwrap_or(false);
157+
158+
if should_flush {
159+
// Take the batch and trigger execution
160+
let batch_to_run = pending_batch.take().unwrap();
161+
*ongoing_count += 1;
162+
let data = self.data.clone();
163+
tokio::spawn(async move { data.run_batch(batch_to_run).await });
164+
}
165+
133166
BatchExecutionAction::Batched {
134167
output_rx,
135-
num_cancelled_tx: batch.num_cancelled_tx.clone(),
168+
num_cancelled_tx,
136169
}
137170
}
138171
}
@@ -173,13 +206,33 @@ struct BatchKickOffNext<'a, R: Runner + 'static> {
173206
impl<'a, R: Runner + 'static> Drop for BatchKickOffNext<'a, R> {
174207
fn drop(&mut self) {
175208
let mut state = self.batcher_data.state.lock().unwrap();
176-
let existing_state = std::mem::take(&mut *state);
177-
let BatcherState::Busy(Some(batch)) = existing_state else {
178-
return;
179-
};
180-
*state = BatcherState::Busy(None);
181-
let data = self.batcher_data.clone();
182-
tokio::spawn(async move { data.run_batch(batch).await });
209+
210+
match &mut *state {
211+
BatcherState::Idle => {
212+
// Nothing to do, already idle
213+
return;
214+
}
215+
BatcherState::Busy {
216+
pending_batch,
217+
ongoing_count,
218+
} => {
219+
// Decrement the ongoing count first
220+
*ongoing_count -= 1;
221+
222+
if *ongoing_count == 0 {
223+
// All batches done, check if there's a pending batch
224+
if let Some(batch) = pending_batch.take() {
225+
// Kick off the pending batch and set ongoing_count to 1
226+
*ongoing_count = 1;
227+
let data = self.batcher_data.clone();
228+
tokio::spawn(async move { data.run_batch(batch).await });
229+
} else {
230+
// No pending batch, transition to Idle
231+
*state = BatcherState::Idle;
232+
}
233+
}
234+
}
235+
}
183236
}
184237
}
185238

@@ -263,7 +316,7 @@ mod tests {
263316
let runner = TestRunner {
264317
recorded_calls: recorded_calls.clone(),
265318
};
266-
let batcher = Arc::new(Batcher::new(runner));
319+
let batcher = Arc::new(Batcher::new(runner, BatcherOptions::default()));
267320

268321
let (n1_tx, n1_rx) = oneshot::channel::<()>();
269322
let (n2_tx, n2_rx) = oneshot::channel::<()>();
@@ -319,4 +372,216 @@ mod tests {
319372

320373
Ok(())
321374
}
375+
376+
#[tokio::test(flavor = "current_thread")]
377+
async fn respects_max_batch_size() -> Result<()> {
378+
let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
379+
let runner = TestRunner {
380+
recorded_calls: recorded_calls.clone(),
381+
};
382+
let batcher = Arc::new(Batcher::new(
383+
runner,
384+
BatcherOptions {
385+
max_batch_size: Some(2),
386+
},
387+
));
388+
389+
let (n1_tx, n1_rx) = oneshot::channel::<()>();
390+
let (n2_tx, n2_rx) = oneshot::channel::<()>();
391+
let (n3_tx, n3_rx) = oneshot::channel::<()>();
392+
let (n4_tx, n4_rx) = oneshot::channel::<()>();
393+
394+
// Submit first call; it should execute inline and block on n1
395+
let b1 = batcher.clone();
396+
let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });
397+
398+
// Wait until the runner has recorded the first inline call
399+
wait_until_len(&recorded_calls, 1).await;
400+
401+
// Submit second call; it should be batched
402+
let b2 = batcher.clone();
403+
let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });
404+
405+
// Submit third call; this should trigger a flush because max_batch_size=2
406+
// The batch [2, 3] should be executed immediately
407+
let b3 = batcher.clone();
408+
let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });
409+
410+
// Wait for the second batch to be recorded
411+
wait_until_len(&recorded_calls, 2).await;
412+
413+
// Verify that the second batch was triggered by max_batch_size
414+
{
415+
let calls = recorded_calls.lock().unwrap();
416+
assert_eq!(calls.len(), 2, "second batch should have started");
417+
assert_eq!(calls[1], vec![2, 3], "second batch should contain [2, 3]");
418+
}
419+
420+
// Submit fourth call; it should wait because there are still ongoing batches
421+
let b4 = batcher.clone();
422+
let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await });
423+
424+
// Give it a moment to ensure no new batch starts
425+
sleep(Duration::from_millis(50)).await;
426+
{
427+
let len_now = recorded_calls.lock().unwrap().len();
428+
assert_eq!(
429+
len_now, 2,
430+
"third batch should not start until all ongoing batches complete"
431+
);
432+
}
433+
434+
// Unblock the first inline call
435+
let _ = n1_tx.send(());
436+
437+
// Wait for first result
438+
let v1 = f1.await??;
439+
assert_eq!(v1, 2);
440+
441+
// Batch [2,3] is still running, so batch [4] shouldn't start yet
442+
sleep(Duration::from_millis(50)).await;
443+
{
444+
let len_now = recorded_calls.lock().unwrap().len();
445+
assert_eq!(
446+
len_now, 2,
447+
"third batch should not start until all ongoing batches complete"
448+
);
449+
}
450+
451+
// Unblock batch [2,3] - this should trigger batch [4] to start
452+
let _ = n2_tx.send(());
453+
let _ = n3_tx.send(());
454+
455+
let v2 = f2.await??;
456+
let v3 = f3.await??;
457+
assert_eq!(v2, 4);
458+
assert_eq!(v3, 6);
459+
460+
// Now batch [4] should start since all previous batches are done
461+
wait_until_len(&recorded_calls, 3).await;
462+
463+
// Unblock batch [4]
464+
let _ = n4_tx.send(());
465+
let v4 = f4.await??;
466+
assert_eq!(v4, 8);
467+
468+
// Validate the call recording: [1], [2, 3] (flushed by max_batch_size), [4]
469+
let calls = recorded_calls.lock().unwrap().clone();
470+
assert_eq!(calls.len(), 3);
471+
assert_eq!(calls[0], vec![1]);
472+
assert_eq!(calls[1], vec![2, 3]);
473+
assert_eq!(calls[2], vec![4]);
474+
475+
Ok(())
476+
}
477+
478+
#[tokio::test(flavor = "current_thread")]
479+
async fn tracks_multiple_concurrent_batches() -> Result<()> {
480+
let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
481+
let runner = TestRunner {
482+
recorded_calls: recorded_calls.clone(),
483+
};
484+
let batcher = Arc::new(Batcher::new(
485+
runner,
486+
BatcherOptions {
487+
max_batch_size: Some(2),
488+
},
489+
));
490+
491+
let (n1_tx, n1_rx) = oneshot::channel::<()>();
492+
let (n2_tx, n2_rx) = oneshot::channel::<()>();
493+
let (n3_tx, n3_rx) = oneshot::channel::<()>();
494+
let (n4_tx, n4_rx) = oneshot::channel::<()>();
495+
let (n5_tx, n5_rx) = oneshot::channel::<()>();
496+
let (n6_tx, n6_rx) = oneshot::channel::<()>();
497+
498+
// Submit first call - executes inline
499+
let b1 = batcher.clone();
500+
let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });
501+
wait_until_len(&recorded_calls, 1).await;
502+
503+
// Submit calls 2-3 - should batch and flush at max_batch_size
504+
let b2 = batcher.clone();
505+
let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });
506+
let b3 = batcher.clone();
507+
let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });
508+
wait_until_len(&recorded_calls, 2).await;
509+
510+
// Submit calls 4-5 - should batch and flush at max_batch_size
511+
let b4 = batcher.clone();
512+
let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await });
513+
let b5 = batcher.clone();
514+
let f5 = tokio::spawn(async move { b5.run((5_i64, n5_rx)).await });
515+
wait_until_len(&recorded_calls, 3).await;
516+
517+
// Submit call 6 - should be batched but not flushed yet
518+
let b6 = batcher.clone();
519+
let f6 = tokio::spawn(async move { b6.run((6_i64, n6_rx)).await });
520+
521+
// Give it a moment to ensure no new batch starts
522+
sleep(Duration::from_millis(50)).await;
523+
{
524+
let len_now = recorded_calls.lock().unwrap().len();
525+
assert_eq!(
526+
len_now, 3,
527+
"fourth batch should not start with ongoing batches"
528+
);
529+
}
530+
531+
// Unblock batch [2, 3] - should not cause [6] to execute yet (batch 1 and batch [4, 5] still ongoing)
532+
let _ = n2_tx.send(());
533+
let _ = n3_tx.send(());
534+
let v2 = f2.await??;
535+
let v3 = f3.await??;
536+
assert_eq!(v2, 4);
537+
assert_eq!(v3, 6);
538+
539+
sleep(Duration::from_millis(50)).await;
540+
{
541+
let len_now = recorded_calls.lock().unwrap().len();
542+
assert_eq!(
543+
len_now, 3,
544+
"batch [6] should still not start (batch 1 and batch [4,5] still ongoing)"
545+
);
546+
}
547+
548+
// Unblock batch [4, 5] - should not cause [6] to execute yet (batch 1 still ongoing)
549+
let _ = n4_tx.send(());
550+
let _ = n5_tx.send(());
551+
let v4 = f4.await??;
552+
let v5 = f5.await??;
553+
assert_eq!(v4, 8);
554+
assert_eq!(v5, 10);
555+
556+
sleep(Duration::from_millis(50)).await;
557+
{
558+
let len_now = recorded_calls.lock().unwrap().len();
559+
assert_eq!(
560+
len_now, 3,
561+
"batch [6] should still not start (batch 1 still ongoing)"
562+
);
563+
}
564+
565+
// Unblock batch 1 - NOW batch [6] should start
566+
let _ = n1_tx.send(());
567+
let v1 = f1.await??;
568+
assert_eq!(v1, 2);
569+
570+
wait_until_len(&recorded_calls, 4).await;
571+
572+
// Unblock batch [6]
573+
let _ = n6_tx.send(());
574+
let v6 = f6.await??;
575+
assert_eq!(v6, 12);
576+
577+
// Validate the call recording
578+
let calls = recorded_calls.lock().unwrap().clone();
579+
assert_eq!(calls.len(), 4);
580+
assert_eq!(calls[0], vec![1]);
581+
assert_eq!(calls[1], vec![2, 3]);
582+
assert_eq!(calls[2], vec![4, 5]);
583+
assert_eq!(calls[3], vec![6]);
584+
585+
Ok(())
586+
}
322587
}

0 commit comments

Comments (0)