Skip to content

Commit 05e9116

Browse files
authored
feat(batching): implement batching util library (#1229)
1 parent 0d738ce commit 05e9116

File tree

3 files changed

+333
-4
lines changed

3 files changed

+333
-4
lines changed

src/service/error.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ pub struct ResidualErrorData {
7979
#[derive(Clone)]
8080
pub struct ResidualError(Arc<ResidualErrorData>);
8181

82+
impl ResidualError {
83+
pub fn new<Err: Display + Debug>(err: &Err) -> Self {
84+
Self(Arc::new(ResidualErrorData {
85+
message: err.to_string(),
86+
debug: err.to_string(),
87+
}))
88+
}
89+
}
90+
8291
impl Display for ResidualError {
8392
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
8493
write!(f, "{}", self.0.message)
@@ -116,10 +125,7 @@ impl SharedError {
116125
SharedErrorState::ResidualErrorMessage(err) => {
117126
return anyhow::Error::from(err.clone());
118127
}
119-
SharedErrorState::Anyhow(err) => ResidualError(Arc::new(ResidualErrorData {
120-
message: format!("{}", err),
121-
debug: format!("{:?}", err),
122-
})),
128+
SharedErrorState::Anyhow(err) => ResidualError::new(err),
123129
};
124130
let orig_state = std::mem::replace(
125131
mut_state,

src/utils/batching.rs

Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
use crate::{prelude::*, service::error::ResidualError};
2+
use tokio::sync::{oneshot, watch};
3+
use tokio_util::task::AbortOnDropHandle;
4+
5+
#[async_trait]
6+
pub trait Runner: Send + Sync {
7+
type Input: Send;
8+
type Output: Send;
9+
10+
async fn run(
11+
&self,
12+
inputs: Vec<Self::Input>,
13+
) -> Result<impl ExactSizeIterator<Item = Self::Output>>;
14+
}
15+
16+
struct Batch<I, O> {
17+
inputs: Vec<I>,
18+
output_txs: Vec<oneshot::Sender<Result<O>>>,
19+
num_cancelled_tx: watch::Sender<usize>,
20+
num_cancelled_rx: watch::Receiver<usize>,
21+
}
22+
23+
impl<I, O> Default for Batch<I, O> {
24+
fn default() -> Self {
25+
let (num_cancelled_tx, num_cancelled_rx) = watch::channel(0);
26+
Self {
27+
inputs: Vec::new(),
28+
output_txs: Vec::new(),
29+
num_cancelled_tx,
30+
num_cancelled_rx,
31+
}
32+
}
33+
}
34+
35+
#[derive(Default)]
36+
enum BatcherState<I, O> {
37+
#[default]
38+
Idle,
39+
Busy(Option<Batch<I, O>>),
40+
}
41+
42+
struct BatcherData<R: Runner + 'static> {
43+
runner: R,
44+
state: Mutex<BatcherState<R::Input, R::Output>>,
45+
}
46+
47+
impl<R: Runner + 'static> BatcherData<R> {
48+
async fn run_batch(self: &Arc<Self>, batch: Batch<R::Input, R::Output>) {
49+
let _kick_off_next = BatchKickOffNext { batcher_data: self };
50+
let num_inputs = batch.inputs.len();
51+
52+
let mut num_cancelled_rx = batch.num_cancelled_rx;
53+
let outputs = tokio::select! {
54+
outputs = self.runner.run(batch.inputs) => {
55+
outputs
56+
}
57+
_ = num_cancelled_rx.wait_for(|v| *v == num_inputs) => {
58+
return;
59+
}
60+
};
61+
62+
match outputs {
63+
Ok(outputs) => {
64+
if outputs.len() != batch.output_txs.len() {
65+
let message = format!(
66+
"Batched output length mismatch: expected {} outputs, got {}",
67+
batch.output_txs.len(),
68+
outputs.len()
69+
);
70+
error!("{message}");
71+
for sender in batch.output_txs {
72+
sender.send(Err(anyhow!("{message}"))).ok();
73+
}
74+
return;
75+
}
76+
for (output, sender) in outputs.zip(batch.output_txs) {
77+
sender.send(Ok(output)).ok();
78+
}
79+
}
80+
Err(err) => {
81+
let mut senders_iter = batch.output_txs.into_iter();
82+
if let Some(sender) = senders_iter.next() {
83+
if senders_iter.len() > 0 {
84+
let residual_err = ResidualError::new(&err);
85+
for sender in senders_iter {
86+
sender.send(Err(residual_err.clone().into())).ok();
87+
}
88+
}
89+
sender.send(Err(err)).ok();
90+
}
91+
}
92+
}
93+
}
94+
}
95+
96+
pub struct Batcher<R: Runner + 'static> {
97+
data: Arc<BatcherData<R>>,
98+
}
99+
100+
enum BatchExecutionAction<R: Runner + 'static> {
101+
Inline {
102+
input: R::Input,
103+
},
104+
Batched {
105+
output_rx: oneshot::Receiver<Result<R::Output>>,
106+
num_cancelled_tx: watch::Sender<usize>,
107+
},
108+
}
109+
impl<R: Runner + 'static> Batcher<R> {
110+
pub fn new(runner: R) -> Self {
111+
Self {
112+
data: Arc::new(BatcherData {
113+
runner,
114+
state: Mutex::new(BatcherState::Idle),
115+
}),
116+
}
117+
}
118+
pub async fn run(&self, input: R::Input) -> Result<R::Output> {
119+
let batch_exec_action: BatchExecutionAction<R> = {
120+
let mut state = self.data.state.lock().unwrap();
121+
match &mut *state {
122+
state @ BatcherState::Idle => {
123+
*state = BatcherState::Busy(None);
124+
BatchExecutionAction::Inline { input }
125+
}
126+
BatcherState::Busy(batch) => {
127+
let batch = batch.get_or_insert_default();
128+
batch.inputs.push(input);
129+
130+
let (output_tx, output_rx) = oneshot::channel();
131+
batch.output_txs.push(output_tx);
132+
133+
BatchExecutionAction::Batched {
134+
output_rx,
135+
num_cancelled_tx: batch.num_cancelled_tx.clone(),
136+
}
137+
}
138+
}
139+
};
140+
match batch_exec_action {
141+
BatchExecutionAction::Inline { input } => {
142+
let _kick_off_next = BatchKickOffNext {
143+
batcher_data: &self.data,
144+
};
145+
146+
let data = self.data.clone();
147+
let handle = AbortOnDropHandle::new(tokio::spawn(async move {
148+
let mut outputs = data.runner.run(vec![input]).await?;
149+
if outputs.len() != 1 {
150+
bail!("Expected 1 output, got {}", outputs.len());
151+
}
152+
Ok(outputs.next().unwrap())
153+
}));
154+
Ok(handle.await??)
155+
}
156+
BatchExecutionAction::Batched {
157+
output_rx,
158+
num_cancelled_tx,
159+
} => {
160+
let mut guard = BatchRecvCancellationGuard::new(Some(num_cancelled_tx));
161+
let output = output_rx.await?;
162+
guard.done();
163+
output
164+
}
165+
}
166+
}
167+
}
168+
169+
struct BatchKickOffNext<'a, R: Runner + 'static> {
170+
batcher_data: &'a Arc<BatcherData<R>>,
171+
}
172+
173+
impl<'a, R: Runner + 'static> Drop for BatchKickOffNext<'a, R> {
174+
fn drop(&mut self) {
175+
let mut state = self.batcher_data.state.lock().unwrap();
176+
let existing_state = std::mem::take(&mut *state);
177+
let BatcherState::Busy(Some(batch)) = existing_state else {
178+
return;
179+
};
180+
*state = BatcherState::Busy(None);
181+
let data = self.batcher_data.clone();
182+
tokio::spawn(async move { data.run_batch(batch).await });
183+
}
184+
}
185+
186+
struct BatchRecvCancellationGuard {
187+
num_cancelled_tx: Option<watch::Sender<usize>>,
188+
}
189+
190+
impl Drop for BatchRecvCancellationGuard {
191+
fn drop(&mut self) {
192+
if let Some(num_cancelled_tx) = self.num_cancelled_tx.take() {
193+
num_cancelled_tx.send_modify(|v| *v += 1);
194+
}
195+
}
196+
}
197+
198+
impl BatchRecvCancellationGuard {
199+
pub fn new(num_cancelled_tx: Option<watch::Sender<usize>>) -> Self {
200+
Self { num_cancelled_tx }
201+
}
202+
203+
pub fn done(&mut self) {
204+
self.num_cancelled_tx = None;
205+
}
206+
}
207+
208+
#[cfg(test)]
209+
mod tests {
210+
use super::*;
211+
use std::sync::{Arc, Mutex};
212+
use tokio::sync::oneshot;
213+
use tokio::time::{Duration, sleep};
214+
215+
/// Runner used by the tests: it logs every call and waits until told to finish.
struct TestRunner {
    // One entry per `run` call, in call order; each entry lists that call's input values
    recorded_calls: Arc<Mutex<Vec<Vec<i64>>>>,
}
219+
220+
#[async_trait]
221+
impl Runner for TestRunner {
222+
type Input = (i64, oneshot::Receiver<()>);
223+
type Output = i64;
224+
225+
async fn run(
226+
&self,
227+
inputs: Vec<Self::Input>,
228+
) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
229+
// Record the values for this invocation (order-agnostic)
230+
let mut values: Vec<i64> = inputs.iter().map(|(v, _)| *v).collect();
231+
values.sort();
232+
self.recorded_calls.lock().unwrap().push(values);
233+
234+
// Split into values and receivers so we can await by value (send-before-wait safe)
235+
let (vals, rxs): (Vec<i64>, Vec<oneshot::Receiver<()>>) =
236+
inputs.into_iter().map(|(v, rx)| (v, rx)).unzip();
237+
238+
// Block until every input's signal is fired
239+
for (_i, rx) in rxs.into_iter().enumerate() {
240+
let _ = rx.await;
241+
}
242+
243+
// Return outputs mapping v -> v * 2
244+
let outputs: Vec<i64> = vals.into_iter().map(|v| v * 2).collect();
245+
Ok(outputs.into_iter())
246+
}
247+
}
248+
249+
async fn wait_until_len(recorded: &Arc<Mutex<Vec<Vec<i64>>>>, expected_len: usize) {
250+
for _ in 0..200 {
251+
// up to ~2s
252+
if recorded.lock().unwrap().len() == expected_len {
253+
return;
254+
}
255+
sleep(Duration::from_millis(10)).await;
256+
}
257+
panic!("timed out waiting for recorded_calls length {expected_len}");
258+
}
259+
260+
#[tokio::test(flavor = "current_thread")]
261+
async fn batches_after_first_inline_call() -> Result<()> {
262+
let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
263+
let runner = TestRunner {
264+
recorded_calls: recorded_calls.clone(),
265+
};
266+
let batcher = Arc::new(Batcher::new(runner));
267+
268+
let (n1_tx, n1_rx) = oneshot::channel::<()>();
269+
let (n2_tx, n2_rx) = oneshot::channel::<()>();
270+
let (n3_tx, n3_rx) = oneshot::channel::<()>();
271+
272+
// Submit first call; it should execute inline and block on n1
273+
let b1 = batcher.clone();
274+
let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });
275+
276+
// Wait until the runner has recorded the first inline call
277+
wait_until_len(&recorded_calls, 1).await;
278+
279+
// Submit the next two calls; they should be batched together and not run yet
280+
let b2 = batcher.clone();
281+
let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });
282+
283+
let b3 = batcher.clone();
284+
let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });
285+
286+
// Ensure no new batch has started yet
287+
{
288+
let len_now = recorded_calls.lock().unwrap().len();
289+
assert_eq!(
290+
len_now, 1,
291+
"second invocation should not have started before unblocking first"
292+
);
293+
}
294+
295+
// Unblock the first call; this should trigger the next batch of [2,3]
296+
let _ = n1_tx.send(());
297+
298+
// Wait for the batch call to be recorded
299+
wait_until_len(&recorded_calls, 2).await;
300+
301+
// First result should now be available
302+
let v1 = f1.await??;
303+
assert_eq!(v1, 2);
304+
305+
// The batched call is waiting on n2 and n3; now unblock both and collect results
306+
let _ = n2_tx.send(());
307+
let _ = n3_tx.send(());
308+
309+
let v2 = f2.await??;
310+
let v3 = f3.await??;
311+
assert_eq!(v2, 4);
312+
assert_eq!(v3, 6);
313+
314+
// Validate the call recording: first [1], then [2, 3]
315+
let calls = recorded_calls.lock().unwrap().clone();
316+
assert_eq!(calls.len(), 2);
317+
assert_eq!(calls[0], vec![1]);
318+
assert_eq!(calls[1], vec![2, 3]);
319+
320+
Ok(())
321+
}
322+
}

src/utils/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod batching;
12
pub mod bytes_decode;
23
pub mod concur_control;
34
pub mod db;

0 commit comments

Comments
 (0)