
Commit 2801285

Move the worker logic in a struct
1 parent 823174c commit 2801285

File tree

2 files changed: +183 -48 lines changed


crates/tasks/src/lib.rs

Lines changed: 10 additions & 3 deletions
@@ -12,6 +12,7 @@ use mas_matrix::HomeserverConnection;
 use mas_router::UrlBuilder;
 use mas_storage::{BoxClock, BoxRepository, RepositoryError, SystemClock};
 use mas_storage_pg::PgRepository;
+use new_queue::QueueRunnerError;
 use rand::SeedableRng;
 use sqlx::{Pool, Postgres};
 use tracing::debug;
@@ -142,7 +143,7 @@ pub async fn init(
     mailer: &Mailer,
     homeserver: impl HomeserverConnection<Error = anyhow::Error> + 'static,
     url_builder: UrlBuilder,
-) -> Result<Monitor<TokioExecutor>, sqlx::Error> {
+) -> Result<Monitor<TokioExecutor>, QueueRunnerError> {
     let state = State::new(
         pool.clone(),
         SystemClock::default(),
@@ -158,13 +159,19 @@ pub async fn init(
     let monitor = self::user::register(name, monitor, &state, &factory);
     let monitor = self::recovery::register(name, monitor, &state, &factory);
     // TODO: we might want to grab the join handle here
-    factory.listen().await?;
+    // TODO: this error isn't right, I just want that to compile
+    factory
+        .listen()
+        .await
+        .map_err(QueueRunnerError::SetupListener)?;
     debug!(?monitor, "workers registered");

+    let mut worker = self::new_queue::QueueWorker::new(state).await?;
+
     // TODO: this is just spawning the task in the background, we probably actually
     // want to wrap that in a structure, and handle graceful shutdown correctly
     tokio::spawn(async move {
-        if let Err(e) = self::new_queue::run(state).await {
+        if let Err(e) = worker.run().await {
             tracing::error!(
                 error = &e as &dyn std::error::Error,
                 "Failed to run new queue"

crates/tasks/src/new_queue.rs

Lines changed: 173 additions & 45 deletions
@@ -3,79 +3,207 @@
 // SPDX-License-Identifier: AGPL-3.0-only
 // Please see LICENSE in the repository root for full details.

-use chrono::Duration;
-use mas_storage::{RepositoryAccess, RepositoryError};
+use chrono::{DateTime, Duration, Utc};
+use mas_storage::{queue::Worker, Clock, RepositoryAccess, RepositoryError};
+use mas_storage_pg::{DatabaseError, PgRepository};
+use rand::{distributions::Uniform, Rng};
+use rand_chacha::ChaChaRng;
+use sqlx::PgPool;
+use thiserror::Error;

 use crate::State;

-pub async fn run(state: State) -> Result<(), RepositoryError> {
-    let span = tracing::info_span!("worker.init", worker.id = tracing::field::Empty);
-    let guard = span.enter();
-    let mut repo = state.repository().await?;
-    let mut rng = state.rng();
-    let clock = state.clock();
+#[derive(Debug, Error)]
+pub enum QueueRunnerError {
+    #[error("Failed to setup listener")]
+    SetupListener(#[source] sqlx::Error),

-    let worker = repo.queue_worker().register(&mut rng, &clock).await?;
-    span.record("worker.id", tracing::field::display(worker.id));
-    repo.save().await?;
+    #[error("Failed to start transaction")]
+    StartTransaction(#[source] sqlx::Error),

-    tracing::info!("Registered worker");
-    drop(guard);
+    #[error("Failed to commit transaction")]
+    CommitTransaction(#[source] sqlx::Error),

-    let mut was_i_the_leader = false;
+    #[error(transparent)]
+    Repository(#[from] RepositoryError),

-    // Record when we last sent a heartbeat
-    let mut last_heartbeat = clock.now();
+    #[error(transparent)]
+    Database(#[from] DatabaseError),

-    loop {
+    #[error("Worker is not the leader")]
+    NotLeader,
+}
+
+const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
+const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
+
+pub struct QueueWorker {
+    rng: ChaChaRng,
+    clock: Box<dyn Clock + Send>,
+    pool: PgPool,
+    registration: Worker,
+    am_i_leader: bool,
+    last_heartbeat: DateTime<Utc>,
+}
+
+impl QueueWorker {
+    #[tracing::instrument(
+        name = "worker.init",
+        skip_all,
+        fields(worker.id)
+    )]
+    pub async fn new(state: State) -> Result<Self, QueueRunnerError> {
+        let mut rng = state.rng();
+        let clock = state.clock();
+        let pool = state.pool().clone();
+
+        let txn = pool
+            .begin()
+            .await
+            .map_err(QueueRunnerError::StartTransaction)?;
+        let mut repo = PgRepository::from_conn(txn);
+
+        let registration = repo.queue_worker().register(&mut rng, &clock).await?;
+        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
+        repo.into_inner()
+            .commit()
+            .await
+            .map_err(QueueRunnerError::CommitTransaction)?;
+
+        tracing::info!("Registered worker");
+        let now = clock.now();
+
+        Ok(Self {
+            rng,
+            clock,
+            pool,
+            registration,
+            am_i_leader: false,
+            last_heartbeat: now,
+        })
+    }
+
+    pub async fn run(&mut self) -> Result<(), QueueRunnerError> {
+        loop {
+            self.run_loop().await?;
+        }
+    }
+
+    #[tracing::instrument(name = "worker.run_loop", skip_all, err)]
+    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
+        self.wait_until_wakeup().await?;
+        self.tick().await?;
+
+        if self.am_i_leader {
+            self.perform_leader_duties().await?;
+        }
+
+        Ok(())
+    }
+
+    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)]
+    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
         // This is to make sure we wake up every second to do the maintenance tasks
-        // Later we might wait on other events, like a PG notification
-        let wakeup_sleep = tokio::time::sleep(std::time::Duration::from_secs(1));
-        wakeup_sleep.await;
+        // We add a little bit of random jitter to the duration, so that we don't get
+        // fully synced workers waking up at the same time after each notification
+        let sleep_duration = self
+            .rng
+            .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
+        tokio::time::sleep(sleep_duration).await;
+        tracing::debug!("Woke up from sleep");
+
+        Ok(())
+    }

-        let span = tracing::info_span!("worker.tick", %worker.id);
-        let _guard = span.enter();
+    fn set_new_leader_state(&mut self, state: bool) {
+        // Do nothing if we were already on that state
+        if state == self.am_i_leader {
+            return;
+        }

+        // If we flipped state, log it
+        self.am_i_leader = state;
+        if self.am_i_leader {
+            tracing::info!("I'm the leader now");
+        } else {
+            tracing::warn!("I am no longer the leader");
+        }
+    }
+
+    #[tracing::instrument(
+        name = "worker.tick",
+        skip_all,
+        fields(worker.id = %self.registration.id),
+        err,
+    )]
+    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
         tracing::debug!("Tick");
-        let now = clock.now();
-        let mut repo = state.repository().await?;
+        let now = self.clock.now();
+
+        let txn = self
+            .pool
+            .begin()
+            .await
+            .map_err(QueueRunnerError::StartTransaction)?;
+        let mut repo = PgRepository::from_conn(txn);

         // We send a heartbeat every minute, to avoid writing to the database too often
         // on a logged table
-        if now - last_heartbeat >= chrono::Duration::minutes(1) {
+        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
             tracing::info!("Sending heartbeat");
-            repo.queue_worker().heartbeat(&clock, &worker).await?;
-            last_heartbeat = now;
+            repo.queue_worker()
+                .heartbeat(&self.clock, &self.registration)
+                .await?;
+            self.last_heartbeat = now;
         }

         // Remove any dead worker leader leases
         repo.queue_worker()
-            .remove_leader_lease_if_expired(&clock)
+            .remove_leader_lease_if_expired(&self.clock)
             .await?;

         // Try to become (or stay) the leader
-        let am_i_the_leader = repo
+        let leader = repo
             .queue_worker()
-            .try_get_leader_lease(&clock, &worker)
+            .try_get_leader_lease(&self.clock, &self.registration)
             .await?;

-        // Log any changes in leadership
-        if !was_i_the_leader && am_i_the_leader {
-            tracing::info!("I'm the leader now");
-        } else if was_i_the_leader && !am_i_the_leader {
-            tracing::warn!("I am no longer the leader");
-        }
-        was_i_the_leader = am_i_the_leader;
+        repo.into_inner()
+            .commit()
+            .await
+            .map_err(QueueRunnerError::CommitTransaction)?;

-        // The leader does all the maintenance work
-        if am_i_the_leader {
-            // We also check if the worker is dead, and if so, we shutdown all the dead
-            // workers that haven't checked in the last two minutes
-            repo.queue_worker()
-                .shutdown_dead_workers(&clock, Duration::minutes(2))
-                .await?;
+        // Save the new leader state
+        self.set_new_leader_state(leader);
+
+        Ok(())
+    }
+
+    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all, err)]
+    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
+        // This should have been checked by the caller, but better safe than sorry
+        if !self.am_i_leader {
+            return Err(QueueRunnerError::NotLeader);
         }

-        repo.save().await?;
+        let txn = self
+            .pool
+            .begin()
+            .await
+            .map_err(QueueRunnerError::StartTransaction)?;
+        let mut repo = PgRepository::from_conn(txn);
+
+        // We also check if the worker is dead, and if so, we shutdown all the dead
+        // workers that haven't checked in the last two minutes
+        repo.queue_worker()
+            .shutdown_dead_workers(&self.clock, Duration::minutes(2))
+            .await?;
+
+        repo.into_inner()
+            .commit()
+            .await
+            .map_err(QueueRunnerError::CommitTransaction)?;
+
+        Ok(())
     }
 }
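
For reference, the new QueueWorker replaces the old free-standing run function: a caller now constructs the worker once and then drives its loop. Below is a minimal, hypothetical call-site sketch mirroring the init changes in crates/tasks/src/lib.rs above; spawn_worker is an illustrative name, and the monitor setup and graceful shutdown are elided.

// Hypothetical helper, not part of this commit: it mirrors how init wires up
// the worker in crates/tasks/src/lib.rs.
async fn spawn_worker(state: crate::State) -> Result<(), self::new_queue::QueueRunnerError> {
    // Registers the worker row in the database and captures the clock, RNG and pool.
    let mut worker = self::new_queue::QueueWorker::new(state).await?;

    // Drive the heartbeat / leader-election loop in the background;
    // run() only returns on error, so log it if it ever does.
    tokio::spawn(async move {
        if let Err(e) = worker.run().await {
            tracing::error!(
                error = &e as &dyn std::error::Error,
                "Failed to run new queue"
            );
        }
    });

    Ok(())
}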
