|
| 1 | +use std::{ |
| 2 | + collections::HashMap, |
| 3 | + sync::{Arc, atomic::AtomicUsize}, |
| 4 | + time::Duration, |
| 5 | +}; |
| 6 | + |
| 7 | +use tokio::sync::{mpsc, oneshot}; |
| 8 | + |
| 9 | +use crate::clock::{ |
| 10 | + Clock, TaskInitiator, Timestamp, coordinator::ClockEvent, timestamp::AtomicTimestamp, |
| 11 | +}; |
| 12 | + |
| 13 | +pub struct MockClockCoordinator { |
| 14 | + time: Arc<AtomicTimestamp>, |
| 15 | + tx: mpsc::UnboundedSender<ClockEvent>, |
| 16 | + rx: mpsc::UnboundedReceiver<ClockEvent>, |
| 17 | + waiter_count: Arc<AtomicUsize>, |
| 18 | + tasks: Arc<AtomicUsize>, |
| 19 | + waiters: HashMap<usize, Waiter>, |
| 20 | +} |
| 21 | + |
| 22 | +impl Default for MockClockCoordinator { |
| 23 | + fn default() -> Self { |
| 24 | + Self::new() |
| 25 | + } |
| 26 | +} |
| 27 | + |
| 28 | +impl MockClockCoordinator { |
| 29 | + pub fn new() -> Self { |
| 30 | + let time = Arc::new(AtomicTimestamp::new(Timestamp::zero())); |
| 31 | + let (tx, rx) = mpsc::unbounded_channel(); |
| 32 | + let waiter_count = Arc::new(AtomicUsize::new(0)); |
| 33 | + let tasks = Arc::new(AtomicUsize::new(0)); |
| 34 | + Self { |
| 35 | + time, |
| 36 | + tx, |
| 37 | + rx, |
| 38 | + waiter_count, |
| 39 | + tasks, |
| 40 | + waiters: HashMap::new(), |
| 41 | + } |
| 42 | + } |
| 43 | + |
| 44 | + pub fn clock(&self) -> Clock { |
| 45 | + Clock::new( |
| 46 | + Duration::from_nanos(1), |
| 47 | + self.time.clone(), |
| 48 | + self.waiter_count.clone(), |
| 49 | + TaskInitiator::new(self.tasks.clone()), |
| 50 | + self.tx.clone(), |
| 51 | + ) |
| 52 | + } |
| 53 | + |
| 54 | + pub fn now(&self) -> Timestamp { |
| 55 | + self.time.load(std::sync::atomic::Ordering::Acquire) |
| 56 | + } |
| 57 | + |
| 58 | + pub fn advance_time(&mut self, until: Timestamp) { |
| 59 | + while let Ok(event) = self.rx.try_recv() { |
| 60 | + match event { |
| 61 | + ClockEvent::Wait { actor, until, done } => { |
| 62 | + if self.waiters.insert(actor, Waiter { until, done }).is_some() { |
| 63 | + panic!("waiter {actor} waited twice"); |
| 64 | + } |
| 65 | + } |
| 66 | + ClockEvent::CancelWait { actor } => { |
| 67 | + if self.waiters.remove(&actor).is_none() { |
| 68 | + panic!("waiter {actor} cancelled a wait twice"); |
| 69 | + } |
| 70 | + } |
| 71 | + ClockEvent::FinishTask => { |
| 72 | + if self.tasks.fetch_sub(1, std::sync::atomic::Ordering::AcqRel) == 0 { |
| 73 | + panic!("cancelled too many tasks"); |
| 74 | + } |
| 75 | + } |
| 76 | + } |
| 77 | + } |
| 78 | + assert_eq!( |
| 79 | + self.waiters.len(), |
| 80 | + self.waiter_count.load(std::sync::atomic::Ordering::Acquire), |
| 81 | + "not every worker is waiting for time to pass" |
| 82 | + ); |
| 83 | + |
| 84 | + self.time.store(until, std::sync::atomic::Ordering::Release); |
| 85 | + self.waiters = std::mem::take(&mut self.waiters) |
| 86 | + .into_iter() |
| 87 | + .filter_map(|(actor, waiter)| { |
| 88 | + if let Some(t) = &waiter.until { |
| 89 | + if *t < until { |
| 90 | + panic!("advanced time too far (waited for {until:?}, next event at {t:?})"); |
| 91 | + } |
| 92 | + if *t == until { |
| 93 | + let _ = waiter.done.send(()); |
| 94 | + return None; |
| 95 | + } |
| 96 | + } |
| 97 | + Some((actor, waiter)) |
| 98 | + }) |
| 99 | + .collect(); |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +struct Waiter { |
| 104 | + until: Option<Timestamp>, |
| 105 | + done: oneshot::Sender<()>, |
| 106 | +} |
0 commit comments