use crate::{prelude::*, service::error::ResidualError};
use tokio::sync::{oneshot, watch};
use tokio_util::task::AbortOnDropHandle;

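/// Runs one batch of inputs in a single call. Implementations are expected to
/// return exactly one output per input, in the same order as `inputs`.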
#[async_trait]
pub trait Runner: Send + Sync {
    type Input: Send;
    type Output: Send;

    async fn run(
        &self,
        inputs: Vec<Self::Input>,
    ) -> Result<impl ExactSizeIterator<Item = Self::Output>>;
}

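/// Inputs accumulated while the runner is busy, together with the channel for
/// delivering each caller's output and a counter of callers that dropped
/// (cancelled) their request before receiving a result.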
struct Batch<I, O> {
    inputs: Vec<I>,
    output_txs: Vec<oneshot::Sender<Result<O>>>,
    num_cancelled_tx: watch::Sender<usize>,
    num_cancelled_rx: watch::Receiver<usize>,
}

impl<I, O> Default for Batch<I, O> {
    fn default() -> Self {
        let (num_cancelled_tx, num_cancelled_rx) = watch::channel(0);
        Self {
            inputs: Vec::new(),
            output_txs: Vec::new(),
            num_cancelled_tx,
            num_cancelled_rx,
        }
    }
}

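/// Whether a batch execution is currently in flight. While `Busy`, newly
/// submitted inputs accumulate in the pending `Batch` (created lazily when the
/// first queued input arrives).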
#[derive(Default)]
enum BatcherState<I, O> {
    #[default]
    Idle,
    Busy(Option<Batch<I, O>>),
}

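/// State shared between callers and the spawned batch tasks.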
struct BatcherData<R: Runner + 'static> {
    runner: R,
    state: Mutex<BatcherState<R::Input, R::Output>>,
}

impl<R: Runner + 'static> BatcherData<R> {
    /// Runs one batch to completion and delivers results to all waiting
    /// callers. Dropping `_kick_off_next` at the end (or on cancellation)
    /// schedules whatever batch accumulated in the meantime.
    async fn run_batch(self: &Arc<Self>, batch: Batch<R::Input, R::Output>) {
        let _kick_off_next = BatchKickOffNext { batcher_data: self };
        let num_inputs = batch.inputs.len();

        // Stop early if every caller in the batch has already cancelled.
        let mut num_cancelled_rx = batch.num_cancelled_rx;
        let outputs = tokio::select! {
            outputs = self.runner.run(batch.inputs) => {
                outputs
            }
            _ = num_cancelled_rx.wait_for(|v| *v == num_inputs) => {
                return;
            }
        };

        match outputs {
            Ok(outputs) => {
                if outputs.len() != batch.output_txs.len() {
                    let message = format!(
                        "Batched output length mismatch: expected {} outputs, got {}",
                        batch.output_txs.len(),
                        outputs.len()
                    );
                    error!("{message}");
                    for sender in batch.output_txs {
                        sender.send(Err(anyhow!("{message}"))).ok();
                    }
                    return;
                }
                for (output, sender) in outputs.zip(batch.output_txs) {
                    sender.send(Ok(output)).ok();
                }
            }
            Err(err) => {
                // The original error goes to the first caller; the remaining
                // callers each get a `ResidualError` derived from it.
                let mut senders_iter = batch.output_txs.into_iter();
                if let Some(sender) = senders_iter.next() {
                    if senders_iter.len() > 0 {
                        let residual_err = ResidualError::new(&err);
                        for sender in senders_iter {
                            sender.send(Err(residual_err.clone().into())).ok();
                        }
                    }
                    sender.send(Err(err)).ok();
                }
            }
        }
    }
}

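/// Coalesces concurrent `run` calls into batched `Runner::run` invocations:
/// a call arriving while the runner is idle executes immediately on its own,
/// and calls arriving while it is busy are queued and executed together as
/// the next batch.
///
/// A minimal usage sketch (illustrative only; `MyRunner` stands in for any
/// `Runner` implementation):
/// ```ignore
/// let batcher = Arc::new(Batcher::new(MyRunner::default()));
/// let output = batcher.run(input).await?;
/// ```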
pub struct Batcher<R: Runner + 'static> {
    data: Arc<BatcherData<R>>,
}

enum BatchExecutionAction<R: Runner + 'static> {
    Inline {
        input: R::Input,
    },
    Batched {
        output_rx: oneshot::Receiver<Result<R::Output>>,
        num_cancelled_tx: watch::Sender<usize>,
    },
}

impl<R: Runner + 'static> Batcher<R> {
    pub fn new(runner: R) -> Self {
        Self {
            data: Arc::new(BatcherData {
                runner,
                state: Mutex::new(BatcherState::Idle),
            }),
        }
    }

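    /// Runs a single input: inline if the runner is idle, otherwise as part of
    /// the next batch. The state lock is only held to decide which path to
    /// take, never across an await.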
    pub async fn run(&self, input: R::Input) -> Result<R::Output> {
        let batch_exec_action: BatchExecutionAction<R> = {
            let mut state = self.data.state.lock().unwrap();
            match &mut *state {
                state @ BatcherState::Idle => {
                    *state = BatcherState::Busy(None);
                    BatchExecutionAction::Inline { input }
                }
                BatcherState::Busy(batch) => {
                    let batch = batch.get_or_insert_default();
                    batch.inputs.push(input);

                    let (output_tx, output_rx) = oneshot::channel();
                    batch.output_txs.push(output_tx);

                    BatchExecutionAction::Batched {
                        output_rx,
                        num_cancelled_tx: batch.num_cancelled_tx.clone(),
                    }
                }
            }
        };
        match batch_exec_action {
            BatchExecutionAction::Inline { input } => {
                // When this call finishes (or is cancelled), kick off whatever
                // batch accumulated while it was running.
                let _kick_off_next = BatchKickOffNext {
                    batcher_data: &self.data,
                };

                let data = self.data.clone();
                let handle = AbortOnDropHandle::new(tokio::spawn(async move {
                    let mut outputs = data.runner.run(vec![input]).await?;
                    if outputs.len() != 1 {
                        bail!("Expected 1 output, got {}", outputs.len());
                    }
                    Ok(outputs.next().unwrap())
                }));
                Ok(handle.await??)
            }
            BatchExecutionAction::Batched {
                output_rx,
                num_cancelled_tx,
            } => {
                // If this future is dropped before the output arrives, the
                // guard bumps the batch's cancellation counter.
                let mut guard = BatchRecvCancellationGuard::new(Some(num_cancelled_tx));
                let output = output_rx.await?;
                guard.done();
                output
            }
        }
    }
}

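/// Drop guard that kicks off the next batch once the current execution is done
/// (or cancelled): it takes any pending batch out of the state and spawns it;
/// if nothing is pending, the state returns to `Idle`.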
struct BatchKickOffNext<'a, R: Runner + 'static> {
    batcher_data: &'a Arc<BatcherData<R>>,
}

impl<'a, R: Runner + 'static> Drop for BatchKickOffNext<'a, R> {
    fn drop(&mut self) {
        let mut state = self.batcher_data.state.lock().unwrap();
        let existing_state = std::mem::take(&mut *state);
        let BatcherState::Busy(Some(batch)) = existing_state else {
            return;
        };
        *state = BatcherState::Busy(None);
        let data = self.batcher_data.clone();
        tokio::spawn(async move { data.run_batch(batch).await });
    }
}

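/// Drop guard held by a batched caller while it awaits its output. If the
/// caller is dropped before calling `done()`, the guard increments the batch's
/// cancellation counter so a fully cancelled batch can be skipped.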
struct BatchRecvCancellationGuard {
    num_cancelled_tx: Option<watch::Sender<usize>>,
}

impl Drop for BatchRecvCancellationGuard {
    fn drop(&mut self) {
        if let Some(num_cancelled_tx) = self.num_cancelled_tx.take() {
            num_cancelled_tx.send_modify(|v| *v += 1);
        }
    }
}

impl BatchRecvCancellationGuard {
    pub fn new(num_cancelled_tx: Option<watch::Sender<usize>>) -> Self {
        Self { num_cancelled_tx }
    }

    pub fn done(&mut self) {
        self.num_cancelled_tx = None;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::sync::oneshot;
    use tokio::time::{Duration, sleep};

    struct TestRunner {
        // Records each call's input values as a vector, in call order
        recorded_calls: Arc<Mutex<Vec<Vec<i64>>>>,
    }

    #[async_trait]
    impl Runner for TestRunner {
        type Input = (i64, oneshot::Receiver<()>);
        type Output = i64;

        async fn run(
            &self,
            inputs: Vec<Self::Input>,
        ) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
            // Record the values for this invocation (sorted, so order-agnostic)
            let mut values: Vec<i64> = inputs.iter().map(|(v, _)| *v).collect();
            values.sort();
            self.recorded_calls.lock().unwrap().push(values);

            // Split into values and receivers; a oneshot receiver still resolves
            // even if its sender fired before we await it.
            let (vals, rxs): (Vec<i64>, Vec<oneshot::Receiver<()>>) =
                inputs.into_iter().unzip();

            // Block until every input's signal is fired
            for rx in rxs {
                let _ = rx.await;
            }

            // Return outputs mapping v -> v * 2
            let outputs: Vec<i64> = vals.into_iter().map(|v| v * 2).collect();
            Ok(outputs.into_iter())
        }
    }

    async fn wait_until_len(recorded: &Arc<Mutex<Vec<Vec<i64>>>>, expected_len: usize) {
        // Poll for up to ~2s
        for _ in 0..200 {
            if recorded.lock().unwrap().len() == expected_len {
                return;
            }
            sleep(Duration::from_millis(10)).await;
        }
        panic!("timed out waiting for recorded_calls length {expected_len}");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn batches_after_first_inline_call() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(runner));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();

        // Submit first call; it should execute inline and block on n1
        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });

        // Wait until the runner has recorded the first inline call
        wait_until_len(&recorded_calls, 1).await;

        // Submit the next two calls; they should be batched together and not run yet
        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });

        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });

        // Ensure no new batch has started yet
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 1,
                "second invocation should not have started before unblocking first"
            );
        }

        // Unblock the first call; this should trigger the next batch of [2, 3]
        let _ = n1_tx.send(());

        // Wait for the batch call to be recorded
        wait_until_len(&recorded_calls, 2).await;

        // First result should now be available
        let v1 = f1.await??;
        assert_eq!(v1, 2);

        // The batched call is waiting on n2 and n3; now unblock both and collect results
        let _ = n2_tx.send(());
        let _ = n3_tx.send(());

        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        // Validate the call recording: first [1], then [2, 3]
        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);

        Ok(())
    }
}