@@ -36,7 +36,10 @@ impl<I, O> Default for Batch<I, O> {
3636enum BatcherState < I , O > {
3737 #[ default]
3838 Idle ,
39- Busy ( Option < Batch < I , O > > ) ,
39+ Busy {
40+ pending_batch : Option < Batch < I , O > > ,
41+ ongoing_count : usize ,
42+ } ,
4043}
4144
4245struct BatcherData < R : Runner + ' static > {
@@ -95,6 +98,7 @@ impl<R: Runner + 'static> BatcherData<R> {
9598
9699pub struct Batcher < R : Runner + ' static > {
97100 data : Arc < BatcherData < R > > ,
101+ options : BatcherOptions ,
98102}
99103
100104enum BatchExecutionAction < R : Runner + ' static > {
@@ -106,33 +110,62 @@ enum BatchExecutionAction<R: Runner + 'static> {
106110 num_cancelled_tx : watch:: Sender < usize > ,
107111 } ,
108112}
113+
114+ #[ derive( Default ) ]
115+ pub struct BatcherOptions {
116+ pub max_batch_size : Option < usize > ,
117+ }
109118impl < R : Runner + ' static > Batcher < R > {
110- pub fn new ( runner : R ) -> Self {
119+ pub fn new ( runner : R , options : BatcherOptions ) -> Self {
111120 Self {
112121 data : Arc :: new ( BatcherData {
113122 runner,
114123 state : Mutex :: new ( BatcherState :: Idle ) ,
115124 } ) ,
125+ options,
116126 }
117127 }
118128 pub async fn run ( & self , input : R :: Input ) -> Result < R :: Output > {
119129 let batch_exec_action: BatchExecutionAction < R > = {
120130 let mut state = self . data . state . lock ( ) . unwrap ( ) ;
121131 match & mut * state {
122132 state @ BatcherState :: Idle => {
123- * state = BatcherState :: Busy ( None ) ;
133+ * state = BatcherState :: Busy {
134+ pending_batch : None ,
135+ ongoing_count : 1 ,
136+ } ;
124137 BatchExecutionAction :: Inline { input }
125138 }
126- BatcherState :: Busy ( batch) => {
127- let batch = batch. get_or_insert_default ( ) ;
139+ BatcherState :: Busy {
140+ pending_batch,
141+ ongoing_count,
142+ } => {
143+ let batch = pending_batch. get_or_insert_default ( ) ;
128144 batch. inputs . push ( input) ;
129145
130146 let ( output_tx, output_rx) = oneshot:: channel ( ) ;
131147 batch. output_txs . push ( output_tx) ;
132148
149+ let num_cancelled_tx = batch. num_cancelled_tx . clone ( ) ;
150+
151+ // Check if we've reached max_batch_size and need to flush immediately
152+ let should_flush = self
153+ . options
154+ . max_batch_size
155+ . map ( |max_size| batch. inputs . len ( ) >= max_size)
156+ . unwrap_or ( false ) ;
157+
158+ if should_flush {
159+ // Take the batch and trigger execution
160+ let batch_to_run = pending_batch. take ( ) . unwrap ( ) ;
161+ * ongoing_count += 1 ;
162+ let data = self . data . clone ( ) ;
163+ tokio:: spawn ( async move { data. run_batch ( batch_to_run) . await } ) ;
164+ }
165+
133166 BatchExecutionAction :: Batched {
134167 output_rx,
135- num_cancelled_tx : batch . num_cancelled_tx . clone ( ) ,
168+ num_cancelled_tx,
136169 }
137170 }
138171 }
@@ -173,13 +206,33 @@ struct BatchKickOffNext<'a, R: Runner + 'static> {
173206impl < ' a , R : Runner + ' static > Drop for BatchKickOffNext < ' a , R > {
174207 fn drop ( & mut self ) {
175208 let mut state = self . batcher_data . state . lock ( ) . unwrap ( ) ;
176- let existing_state = std:: mem:: take ( & mut * state) ;
177- let BatcherState :: Busy ( Some ( batch) ) = existing_state else {
178- return ;
179- } ;
180- * state = BatcherState :: Busy ( None ) ;
181- let data = self . batcher_data . clone ( ) ;
182- tokio:: spawn ( async move { data. run_batch ( batch) . await } ) ;
209+
210+ match & mut * state {
211+ BatcherState :: Idle => {
212+ // Nothing to do, already idle
213+ return ;
214+ }
215+ BatcherState :: Busy {
216+ pending_batch,
217+ ongoing_count,
218+ } => {
219+ // Decrement the ongoing count first
220+ * ongoing_count -= 1 ;
221+
222+ if * ongoing_count == 0 {
223+ // All batches done, check if there's a pending batch
224+ if let Some ( batch) = pending_batch. take ( ) {
225+ // Kick off the pending batch and set ongoing_count to 1
226+ * ongoing_count = 1 ;
227+ let data = self . batcher_data . clone ( ) ;
228+ tokio:: spawn ( async move { data. run_batch ( batch) . await } ) ;
229+ } else {
230+ // No pending batch, transition to Idle
231+ * state = BatcherState :: Idle ;
232+ }
233+ }
234+ }
235+ }
183236 }
184237}
185238
@@ -263,7 +316,7 @@ mod tests {
263316 let runner = TestRunner {
264317 recorded_calls : recorded_calls. clone ( ) ,
265318 } ;
266- let batcher = Arc :: new ( Batcher :: new ( runner) ) ;
319+ let batcher = Arc :: new ( Batcher :: new ( runner, BatcherOptions :: default ( ) ) ) ;
267320
268321 let ( n1_tx, n1_rx) = oneshot:: channel :: < ( ) > ( ) ;
269322 let ( n2_tx, n2_rx) = oneshot:: channel :: < ( ) > ( ) ;
@@ -319,4 +372,216 @@ mod tests {
319372
320373 Ok ( ( ) )
321374 }
375+
376+ #[ tokio:: test( flavor = "current_thread" ) ]
377+ async fn respects_max_batch_size ( ) -> Result < ( ) > {
378+ let recorded_calls = Arc :: new ( Mutex :: new ( Vec :: < Vec < i64 > > :: new ( ) ) ) ;
379+ let runner = TestRunner {
380+ recorded_calls : recorded_calls. clone ( ) ,
381+ } ;
382+ let batcher = Arc :: new ( Batcher :: new (
383+ runner,
384+ BatcherOptions {
385+ max_batch_size : Some ( 2 ) ,
386+ } ,
387+ ) ) ;
388+
389+ let ( n1_tx, n1_rx) = oneshot:: channel :: < ( ) > ( ) ;
390+ let ( n2_tx, n2_rx) = oneshot:: channel :: < ( ) > ( ) ;
391+ let ( n3_tx, n3_rx) = oneshot:: channel :: < ( ) > ( ) ;
392+ let ( n4_tx, n4_rx) = oneshot:: channel :: < ( ) > ( ) ;
393+
394+ // Submit first call; it should execute inline and block on n1
395+ let b1 = batcher. clone ( ) ;
396+ let f1 = tokio:: spawn ( async move { b1. run ( ( 1_i64 , n1_rx) ) . await } ) ;
397+
398+ // Wait until the runner has recorded the first inline call
399+ wait_until_len ( & recorded_calls, 1 ) . await ;
400+
401+ // Submit second call; it should be batched
402+ let b2 = batcher. clone ( ) ;
403+ let f2 = tokio:: spawn ( async move { b2. run ( ( 2_i64 , n2_rx) ) . await } ) ;
404+
405+ // Submit third call; this should trigger a flush because max_batch_size=2
406+ // The batch [2, 3] should be executed immediately
407+ let b3 = batcher. clone ( ) ;
408+ let f3 = tokio:: spawn ( async move { b3. run ( ( 3_i64 , n3_rx) ) . await } ) ;
409+
410+ // Wait for the second batch to be recorded
411+ wait_until_len ( & recorded_calls, 2 ) . await ;
412+
413+ // Verify that the second batch was triggered by max_batch_size
414+ {
415+ let calls = recorded_calls. lock ( ) . unwrap ( ) ;
416+ assert_eq ! ( calls. len( ) , 2 , "second batch should have started" ) ;
417+ assert_eq ! ( calls[ 1 ] , vec![ 2 , 3 ] , "second batch should contain [2, 3]" ) ;
418+ }
419+
420+ // Submit fourth call; it should wait because there are still ongoing batches
421+ let b4 = batcher. clone ( ) ;
422+ let f4 = tokio:: spawn ( async move { b4. run ( ( 4_i64 , n4_rx) ) . await } ) ;
423+
424+ // Give it a moment to ensure no new batch starts
425+ sleep ( Duration :: from_millis ( 50 ) ) . await ;
426+ {
427+ let len_now = recorded_calls. lock ( ) . unwrap ( ) . len ( ) ;
428+ assert_eq ! (
429+ len_now, 2 ,
430+ "third batch should not start until all ongoing batches complete"
431+ ) ;
432+ }
433+
434+ // Unblock the first inline call
435+ let _ = n1_tx. send ( ( ) ) ;
436+
437+ // Wait for first result
438+ let v1 = f1. await ??;
439+ assert_eq ! ( v1, 2 ) ;
440+
441+ // Batch [2,3] is still running, so batch [4] shouldn't start yet
442+ sleep ( Duration :: from_millis ( 50 ) ) . await ;
443+ {
444+ let len_now = recorded_calls. lock ( ) . unwrap ( ) . len ( ) ;
445+ assert_eq ! (
446+ len_now, 2 ,
447+ "third batch should not start until all ongoing batches complete"
448+ ) ;
449+ }
450+
451+ // Unblock batch [2,3] - this should trigger batch [4] to start
452+ let _ = n2_tx. send ( ( ) ) ;
453+ let _ = n3_tx. send ( ( ) ) ;
454+
455+ let v2 = f2. await ??;
456+ let v3 = f3. await ??;
457+ assert_eq ! ( v2, 4 ) ;
458+ assert_eq ! ( v3, 6 ) ;
459+
460+ // Now batch [4] should start since all previous batches are done
461+ wait_until_len ( & recorded_calls, 3 ) . await ;
462+
463+ // Unblock batch [4]
464+ let _ = n4_tx. send ( ( ) ) ;
465+ let v4 = f4. await ??;
466+ assert_eq ! ( v4, 8 ) ;
467+
468+ // Validate the call recording: [1], [2, 3] (flushed by max_batch_size), [4]
469+ let calls = recorded_calls. lock ( ) . unwrap ( ) . clone ( ) ;
470+ assert_eq ! ( calls. len( ) , 3 ) ;
471+ assert_eq ! ( calls[ 0 ] , vec![ 1 ] ) ;
472+ assert_eq ! ( calls[ 1 ] , vec![ 2 , 3 ] ) ;
473+ assert_eq ! ( calls[ 2 ] , vec![ 4 ] ) ;
474+
475+ Ok ( ( ) )
476+ }
477+
478+ #[ tokio:: test( flavor = "current_thread" ) ]
479+ async fn tracks_multiple_concurrent_batches ( ) -> Result < ( ) > {
480+ let recorded_calls = Arc :: new ( Mutex :: new ( Vec :: < Vec < i64 > > :: new ( ) ) ) ;
481+ let runner = TestRunner {
482+ recorded_calls : recorded_calls. clone ( ) ,
483+ } ;
484+ let batcher = Arc :: new ( Batcher :: new (
485+ runner,
486+ BatcherOptions {
487+ max_batch_size : Some ( 2 ) ,
488+ } ,
489+ ) ) ;
490+
491+ let ( n1_tx, n1_rx) = oneshot:: channel :: < ( ) > ( ) ;
492+ let ( n2_tx, n2_rx) = oneshot:: channel :: < ( ) > ( ) ;
493+ let ( n3_tx, n3_rx) = oneshot:: channel :: < ( ) > ( ) ;
494+ let ( n4_tx, n4_rx) = oneshot:: channel :: < ( ) > ( ) ;
495+ let ( n5_tx, n5_rx) = oneshot:: channel :: < ( ) > ( ) ;
496+ let ( n6_tx, n6_rx) = oneshot:: channel :: < ( ) > ( ) ;
497+
498+ // Submit first call - executes inline
499+ let b1 = batcher. clone ( ) ;
500+ let f1 = tokio:: spawn ( async move { b1. run ( ( 1_i64 , n1_rx) ) . await } ) ;
501+ wait_until_len ( & recorded_calls, 1 ) . await ;
502+
503+ // Submit calls 2-3 - should batch and flush at max_batch_size
504+ let b2 = batcher. clone ( ) ;
505+ let f2 = tokio:: spawn ( async move { b2. run ( ( 2_i64 , n2_rx) ) . await } ) ;
506+ let b3 = batcher. clone ( ) ;
507+ let f3 = tokio:: spawn ( async move { b3. run ( ( 3_i64 , n3_rx) ) . await } ) ;
508+ wait_until_len ( & recorded_calls, 2 ) . await ;
509+
510+ // Submit calls 4-5 - should batch and flush at max_batch_size
511+ let b4 = batcher. clone ( ) ;
512+ let f4 = tokio:: spawn ( async move { b4. run ( ( 4_i64 , n4_rx) ) . await } ) ;
513+ let b5 = batcher. clone ( ) ;
514+ let f5 = tokio:: spawn ( async move { b5. run ( ( 5_i64 , n5_rx) ) . await } ) ;
515+ wait_until_len ( & recorded_calls, 3 ) . await ;
516+
517+ // Submit call 6 - should be batched but not flushed yet
518+ let b6 = batcher. clone ( ) ;
519+ let f6 = tokio:: spawn ( async move { b6. run ( ( 6_i64 , n6_rx) ) . await } ) ;
520+
521+ // Give it a moment to ensure no new batch starts
522+ sleep ( Duration :: from_millis ( 50 ) ) . await ;
523+ {
524+ let len_now = recorded_calls. lock ( ) . unwrap ( ) . len ( ) ;
525+ assert_eq ! (
526+ len_now, 3 ,
527+ "fourth batch should not start with ongoing batches"
528+ ) ;
529+ }
530+
531+ // Unblock batch [2, 3] - should not cause [6] to execute yet (batch 1 still ongoing)
532+ let _ = n2_tx. send ( ( ) ) ;
533+ let _ = n3_tx. send ( ( ) ) ;
534+ let v2 = f2. await ??;
535+ let v3 = f3. await ??;
536+ assert_eq ! ( v2, 4 ) ;
537+ assert_eq ! ( v3, 6 ) ;
538+
539+ sleep ( Duration :: from_millis ( 50 ) ) . await ;
540+ {
541+ let len_now = recorded_calls. lock ( ) . unwrap ( ) . len ( ) ;
542+ assert_eq ! (
543+ len_now, 3 ,
544+ "batch [6] should still not start (batch 1 and batch [4,5] still ongoing)"
545+ ) ;
546+ }
547+
548+ // Unblock batch [4, 5] - should not cause [6] to execute yet (batch 1 still ongoing)
549+ let _ = n4_tx. send ( ( ) ) ;
550+ let _ = n5_tx. send ( ( ) ) ;
551+ let v4 = f4. await ??;
552+ let v5 = f5. await ??;
553+ assert_eq ! ( v4, 8 ) ;
554+ assert_eq ! ( v5, 10 ) ;
555+
556+ sleep ( Duration :: from_millis ( 50 ) ) . await ;
557+ {
558+ let len_now = recorded_calls. lock ( ) . unwrap ( ) . len ( ) ;
559+ assert_eq ! (
560+ len_now, 3 ,
561+ "batch [6] should still not start (batch 1 still ongoing)"
562+ ) ;
563+ }
564+
565+ // Unblock batch 1 - NOW batch [6] should start
566+ let _ = n1_tx. send ( ( ) ) ;
567+ let v1 = f1. await ??;
568+ assert_eq ! ( v1, 2 ) ;
569+
570+ wait_until_len ( & recorded_calls, 4 ) . await ;
571+
572+ // Unblock batch [6]
573+ let _ = n6_tx. send ( ( ) ) ;
574+ let v6 = f6. await ??;
575+ assert_eq ! ( v6, 12 ) ;
576+
577+ // Validate the call recording
578+ let calls = recorded_calls. lock ( ) . unwrap ( ) . clone ( ) ;
579+ assert_eq ! ( calls. len( ) , 4 ) ;
580+ assert_eq ! ( calls[ 0 ] , vec![ 1 ] ) ;
581+ assert_eq ! ( calls[ 1 ] , vec![ 2 , 3 ] ) ;
582+ assert_eq ! ( calls[ 2 ] , vec![ 4 , 5 ] ) ;
583+ assert_eq ! ( calls[ 3 ] , vec![ 6 ] ) ;
584+
585+ Ok ( ( ) )
586+ }
322587}
0 commit comments