3 | 3 | // SPDX-License-Identifier: AGPL-3.0-only
4 | 4 | // Please see LICENSE in the repository root for full details.
5 | 5 |
6 | | -use chrono::Duration;
7 | | -use mas_storage::{RepositoryAccess, RepositoryError};
| 6 | +use chrono::{DateTime, Duration, Utc};
| 7 | +use mas_storage::{queue::Worker, Clock, RepositoryAccess, RepositoryError};
| 8 | +use mas_storage_pg::{DatabaseError, PgRepository};
| 9 | +use rand::{distributions::Uniform, Rng};
| 10 | +use rand_chacha::ChaChaRng;
| 11 | +use sqlx::PgPool;
| 12 | +use thiserror::Error;
8 | 13 |
9 | 14 | use crate::State; |
10 | 15 |
11 | | -pub async fn run(state: State) -> Result<(), RepositoryError> { |
12 | | - let span = tracing::info_span!("worker.init", worker.id = tracing::field::Empty); |
13 | | - let guard = span.enter(); |
14 | | - let mut repo = state.repository().await?; |
15 | | - let mut rng = state.rng(); |
16 | | - let clock = state.clock(); |
| 16 | +#[derive(Debug, Error)] |
| 17 | +pub enum QueueRunnerError { |
| 18 | + #[error("Failed to setup listener")] |
| 19 | + SetupListener(#[source] sqlx::Error), |
17 | 20 |
18 | | - let worker = repo.queue_worker().register(&mut rng, &clock).await?; |
19 | | - span.record("worker.id", tracing::field::display(worker.id)); |
20 | | - repo.save().await?; |
| 21 | + #[error("Failed to start transaction")] |
| 22 | + StartTransaction(#[source] sqlx::Error), |
21 | 23 |
22 | | - tracing::info!("Registered worker"); |
23 | | - drop(guard); |
| 24 | + #[error("Failed to commit transaction")] |
| 25 | + CommitTransaction(#[source] sqlx::Error), |
24 | 26 |
25 | | - let mut was_i_the_leader = false; |
| 27 | + #[error(transparent)] |
| 28 | + Repository(#[from] RepositoryError), |
26 | 29 |
27 | | - // Record when we last sent a heartbeat |
28 | | - let mut last_heartbeat = clock.now(); |
| 30 | + #[error(transparent)] |
| 31 | + Database(#[from] DatabaseError), |
29 | 32 |
30 | | - loop { |
| 33 | + #[error("Worker is not the leader")] |
| 34 | + NotLeader, |
| 35 | +} |
| 36 | + |
| 37 | +const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900); |
| 38 | +const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100); |
| 39 | + |
| 40 | +pub struct QueueWorker { |
| 41 | + rng: ChaChaRng, |
| 42 | + clock: Box<dyn Clock + Send>, |
| 43 | + pool: PgPool, |
| 44 | + registration: Worker, |
| 45 | + am_i_leader: bool, |
| 46 | + last_heartbeat: DateTime<Utc>, |
| 47 | +} |
| 48 | + |
| 49 | +impl QueueWorker { |
| 50 | + #[tracing::instrument( |
| 51 | + name = "worker.init", |
| 52 | + skip_all, |
| 53 | + fields(worker.id) |
| 54 | + )] |
| 55 | + pub async fn new(state: State) -> Result<Self, QueueRunnerError> { |
| 56 | + let mut rng = state.rng(); |
| 57 | + let clock = state.clock(); |
| 58 | + let pool = state.pool().clone(); |
| 59 | + |
| 60 | + let txn = pool |
| 61 | + .begin() |
| 62 | + .await |
| 63 | + .map_err(QueueRunnerError::StartTransaction)?; |
| 64 | + let mut repo = PgRepository::from_conn(txn); |
| 65 | + |
| 66 | + let registration = repo.queue_worker().register(&mut rng, &clock).await?; |
| 67 | + tracing::Span::current().record("worker.id", tracing::field::display(registration.id)); |
| 68 | + repo.into_inner() |
| 69 | + .commit() |
| 70 | + .await |
| 71 | + .map_err(QueueRunnerError::CommitTransaction)?; |
| 72 | + |
| 73 | + tracing::info!("Registered worker"); |
| 74 | + let now = clock.now(); |
| 75 | + |
| 76 | + Ok(Self { |
| 77 | + rng, |
| 78 | + clock, |
| 79 | + pool, |
| 80 | + registration, |
| 81 | + am_i_leader: false, |
| 82 | + last_heartbeat: now, |
| 83 | + }) |
| 84 | + } |
| 85 | + |
| 86 | + pub async fn run(&mut self) -> Result<(), QueueRunnerError> { |
| 87 | + loop { |
| 88 | + self.run_loop().await?; |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + #[tracing::instrument(name = "worker.run_loop", skip_all, err)] |
| 93 | + async fn run_loop(&mut self) -> Result<(), QueueRunnerError> { |
| 94 | + self.wait_until_wakeup().await?; |
| 95 | + self.tick().await?; |
| 96 | + |
| 97 | + if self.am_i_leader { |
| 98 | + self.perform_leader_duties().await?; |
| 99 | + } |
| 100 | + |
| 101 | + Ok(()) |
| 102 | + } |
| 103 | + |
| 104 | + #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)] |
| 105 | + async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> { |
31 | 106 | // This is to make sure we wake up every second to do the maintenance tasks |
32 | | - // Later we might wait on other events, like a PG notification |
33 | | - let wakeup_sleep = tokio::time::sleep(std::time::Duration::from_secs(1)); |
34 | | - wakeup_sleep.await; |
| 107 | + // We add a little bit of random jitter to the duration, so that we don't get |
| 108 | + // fully synced workers waking up at the same time after each notification |
| 109 | + let sleep_duration = self |
| 110 | + .rng |
| 111 | + .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION)); |
| 112 | + tokio::time::sleep(sleep_duration).await; |
| 113 | + tracing::debug!("Woke up from sleep"); |
| 114 | + |
| 115 | + Ok(()) |
| 116 | + } |
35 | 117 |
36 | | - let span = tracing::info_span!("worker.tick", %worker.id); |
37 | | - let _guard = span.enter(); |
| 118 | + fn set_new_leader_state(&mut self, state: bool) { |
| 119 | + // Do nothing if we were already in that state
| 120 | + if state == self.am_i_leader { |
| 121 | + return; |
| 122 | + } |
38 | 123 |
| 124 | + // If we flipped state, log it |
| 125 | + self.am_i_leader = state; |
| 126 | + if self.am_i_leader { |
| 127 | + tracing::info!("I'm the leader now"); |
| 128 | + } else { |
| 129 | + tracing::warn!("I am no longer the leader"); |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + #[tracing::instrument( |
| 134 | + name = "worker.tick", |
| 135 | + skip_all, |
| 136 | + fields(worker.id = %self.registration.id), |
| 137 | + err, |
| 138 | + )] |
| 139 | + async fn tick(&mut self) -> Result<(), QueueRunnerError> { |
39 | 140 | tracing::debug!("Tick"); |
40 | | - let now = clock.now(); |
41 | | - let mut repo = state.repository().await?; |
| 141 | + let now = self.clock.now(); |
| 142 | + |
| 143 | + let txn = self |
| 144 | + .pool |
| 145 | + .begin() |
| 146 | + .await |
| 147 | + .map_err(QueueRunnerError::StartTransaction)?; |
| 148 | + let mut repo = PgRepository::from_conn(txn); |
42 | 149 |
43 | 150 | // We send a heartbeat every minute, to avoid writing to the database too often |
44 | 151 | // on a logged table |
45 | | - if now - last_heartbeat >= chrono::Duration::minutes(1) { |
| 152 | + if now - self.last_heartbeat >= chrono::Duration::minutes(1) { |
46 | 153 | tracing::info!("Sending heartbeat"); |
47 | | - repo.queue_worker().heartbeat(&clock, &worker).await?; |
48 | | - last_heartbeat = now; |
| 154 | + repo.queue_worker() |
| 155 | + .heartbeat(&self.clock, &self.registration) |
| 156 | + .await?; |
| 157 | + self.last_heartbeat = now; |
49 | 158 | } |
50 | 159 |
51 | 160 | // Remove any dead worker leader leases |
52 | 161 | repo.queue_worker() |
53 | | - .remove_leader_lease_if_expired(&clock) |
| 162 | + .remove_leader_lease_if_expired(&self.clock) |
54 | 163 | .await?; |
55 | 164 |
56 | 165 | // Try to become (or stay) the leader |
57 | | - let am_i_the_leader = repo |
| 166 | + let leader = repo |
58 | 167 | .queue_worker() |
59 | | - .try_get_leader_lease(&clock, &worker) |
| 168 | + .try_get_leader_lease(&self.clock, &self.registration) |
60 | 169 | .await?; |
61 | 170 |
62 | | - // Log any changes in leadership |
63 | | - if !was_i_the_leader && am_i_the_leader { |
64 | | - tracing::info!("I'm the leader now"); |
65 | | - } else if was_i_the_leader && !am_i_the_leader { |
66 | | - tracing::warn!("I am no longer the leader"); |
67 | | - } |
68 | | - was_i_the_leader = am_i_the_leader; |
| 171 | + repo.into_inner() |
| 172 | + .commit() |
| 173 | + .await |
| 174 | + .map_err(QueueRunnerError::CommitTransaction)?; |
69 | 175 |
70 | | - // The leader does all the maintenance work |
71 | | - if am_i_the_leader { |
72 | | - // We also check if the worker is dead, and if so, we shutdown all the dead |
73 | | - // workers that haven't checked in the last two minutes |
74 | | - repo.queue_worker() |
75 | | - .shutdown_dead_workers(&clock, Duration::minutes(2)) |
76 | | - .await?; |
| 176 | + // Save the new leader state |
| 177 | + self.set_new_leader_state(leader); |
| 178 | + |
| 179 | + Ok(()) |
| 180 | + } |
| 181 | + |
| 182 | + #[tracing::instrument(name = "worker.perform_leader_duties", skip_all, err)] |
| 183 | + async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> { |
| 184 | + // This should have been checked by the caller, but better safe than sorry |
| 185 | + if !self.am_i_leader { |
| 186 | + return Err(QueueRunnerError::NotLeader); |
77 | 187 | } |
78 | 188 |
79 | | - repo.save().await?; |
| 189 | + let txn = self |
| 190 | + .pool |
| 191 | + .begin() |
| 192 | + .await |
| 193 | + .map_err(QueueRunnerError::StartTransaction)?; |
| 194 | + let mut repo = PgRepository::from_conn(txn); |
| 195 | + |
| 196 | + // We also check if any workers are dead, and if so, we shut down all the dead
| 197 | + // workers that haven't checked in within the last two minutes
| 198 | + repo.queue_worker() |
| 199 | + .shutdown_dead_workers(&self.clock, Duration::minutes(2)) |
| 200 | + .await?; |
| 201 | + |
| 202 | + repo.into_inner() |
| 203 | + .commit() |
| 204 | + .await |
| 205 | + .map_err(QueueRunnerError::CommitTransaction)?; |
| 206 | + |
| 207 | + Ok(()) |
80 | 208 | } |
81 | 209 | } |