From a0aeb2d4efeb3c9c345619ff6c0b0e22909a7f17 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 7 Oct 2024 12:07:09 +0200 Subject: [PATCH 01/17] New job queue: worker registration and leader election --- ...0f3f086e4e62c6de9d6864a6a11a2470ebe62.json | 15 ++ ...0dc74505b22c681322bd99b62c2a540c6cd35.json | 15 ++ ...5b531b9873f4139eadcbf1450e726b9a27379.json | 15 ++ ...d356a4ed86fd33400066e422545ffc55f9aa9.json | 16 ++ ...327d03b29fe413d57cce21c67b6d539f59e7d.json | 15 ++ ...84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json | 14 ++ .../20241004075132_queue_worker.sql | 37 +++ crates/storage-pg/src/lib.rs | 1 + crates/storage-pg/src/queue/mod.rs | 8 + crates/storage-pg/src/queue/worker.rs | 237 ++++++++++++++++++ crates/storage-pg/src/repository.rs | 7 + crates/storage/src/lib.rs | 1 + crates/storage/src/queue/mod.rs | 10 + crates/storage/src/queue/worker.rs | 130 ++++++++++ crates/storage/src/repository.rs | 17 ++ crates/tasks/src/lib.rs | 24 +- crates/tasks/src/new_queue.rs | 81 ++++++ 17 files changed, 639 insertions(+), 4 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-12c4577701416a9dc23708c46700f3f086e4e62c6de9d6864a6a11a2470ebe62.json create mode 100644 crates/storage-pg/.sqlx/query-5f2199865fae3a969bb37429dd70dc74505b22c681322bd99b62c2a540c6cd35.json create mode 100644 crates/storage-pg/.sqlx/query-6bd38759f569fcf972924d12f565b531b9873f4139eadcbf1450e726b9a27379.json create mode 100644 crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json create mode 100644 crates/storage-pg/.sqlx/query-966ca0f7eebd2896c007b2fd6e9327d03b29fe413d57cce21c67b6d539f59e7d.json create mode 100644 crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json create mode 100644 crates/storage-pg/migrations/20241004075132_queue_worker.sql create mode 100644 crates/storage-pg/src/queue/mod.rs create mode 100644 crates/storage-pg/src/queue/worker.rs create mode 100644 crates/storage/src/queue/mod.rs create mode 100644 crates/storage/src/queue/worker.rs create mode 100644 crates/tasks/src/new_queue.rs diff --git a/crates/storage-pg/.sqlx/query-12c4577701416a9dc23708c46700f3f086e4e62c6de9d6864a6a11a2470ebe62.json b/crates/storage-pg/.sqlx/query-12c4577701416a9dc23708c46700f3f086e4e62c6de9d6864a6a11a2470ebe62.json new file mode 100644 index 000000000..dce1983fe --- /dev/null +++ b/crates/storage-pg/.sqlx/query-12c4577701416a9dc23708c46700f3f086e4e62c6de9d6864a6a11a2470ebe62.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)\n VALUES ($1, $2, $2)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "12c4577701416a9dc23708c46700f3f086e4e62c6de9d6864a6a11a2470ebe62" +} diff --git a/crates/storage-pg/.sqlx/query-5f2199865fae3a969bb37429dd70dc74505b22c681322bd99b62c2a540c6cd35.json b/crates/storage-pg/.sqlx/query-5f2199865fae3a969bb37429dd70dc74505b22c681322bd99b62c2a540c6cd35.json new file mode 100644 index 000000000..364a1c6b6 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-5f2199865fae3a969bb37429dd70dc74505b22c681322bd99b62c2a540c6cd35.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_workers\n SET shutdown_at = $2\n WHERE queue_worker_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": 
"5f2199865fae3a969bb37429dd70dc74505b22c681322bd99b62c2a540c6cd35" +} diff --git a/crates/storage-pg/.sqlx/query-6bd38759f569fcf972924d12f565b531b9873f4139eadcbf1450e726b9a27379.json b/crates/storage-pg/.sqlx/query-6bd38759f569fcf972924d12f565b531b9873f4139eadcbf1450e726b9a27379.json new file mode 100644 index 000000000..4898fc432 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-6bd38759f569fcf972924d12f565b531b9873f4139eadcbf1450e726b9a27379.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_workers\n SET shutdown_at = $1\n WHERE shutdown_at IS NULL\n AND last_seen_at < $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "6bd38759f569fcf972924d12f565b531b9873f4139eadcbf1450e726b9a27379" +} diff --git a/crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json b/crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json new file mode 100644 index 000000000..9195a9d4d --- /dev/null +++ b/crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)\n VALUES ($1, $2, $3)\n ON CONFLICT (active)\n DO UPDATE SET expires_at = EXCLUDED.expires_at\n WHERE queue_leader.queue_worker_id = $3\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9" +} diff --git a/crates/storage-pg/.sqlx/query-966ca0f7eebd2896c007b2fd6e9327d03b29fe413d57cce21c67b6d539f59e7d.json b/crates/storage-pg/.sqlx/query-966ca0f7eebd2896c007b2fd6e9327d03b29fe413d57cce21c67b6d539f59e7d.json new file mode 100644 index 000000000..3e1fb3580 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-966ca0f7eebd2896c007b2fd6e9327d03b29fe413d57cce21c67b6d539f59e7d.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_workers\n SET last_seen_at = $2\n WHERE queue_worker_id = $1 AND shutdown_at IS NULL\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "966ca0f7eebd2896c007b2fd6e9327d03b29fe413d57cce21c67b6d539f59e7d" +} diff --git a/crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json b/crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json new file mode 100644 index 000000000..af6213a8a --- /dev/null +++ b/crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM queue_leader\n WHERE expires_at < $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5" +} diff --git a/crates/storage-pg/migrations/20241004075132_queue_worker.sql b/crates/storage-pg/migrations/20241004075132_queue_worker.sql new file mode 100644 index 000000000..07b49d22d --- /dev/null +++ b/crates/storage-pg/migrations/20241004075132_queue_worker.sql @@ -0,0 +1,37 @@ +-- Copyright 2024 New Vector Ltd. 
+--
+-- SPDX-License-Identifier: AGPL-3.0-only
+-- Please see LICENSE in the repository root for full details.
+
+-- This table stores information about workers, mostly to track their health
+CREATE TABLE queue_workers (
+    queue_worker_id UUID NOT NULL PRIMARY KEY,
+
+    -- When the worker was registered
+    registered_at TIMESTAMP WITH TIME ZONE NOT NULL,
+
+    -- When the worker was last seen
+    last_seen_at TIMESTAMP WITH TIME ZONE NOT NULL,
+
+    -- When the worker was shut down
+    shutdown_at TIMESTAMP WITH TIME ZONE
+);
+
+-- This single-row table stores the leader of the queue
+-- The leader is responsible for running maintenance tasks
+CREATE UNLOGGED TABLE queue_leader (
+    -- This makes the row unique
+    active BOOLEAN NOT NULL DEFAULT TRUE UNIQUE,
+
+    -- When the leader was elected
+    elected_at TIMESTAMP WITH TIME ZONE NOT NULL,
+
+    -- Until when the lease is valid
+    expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
+
+    -- The worker ID of the leader
+    queue_worker_id UUID NOT NULL REFERENCES queue_workers (queue_worker_id),
+
+    -- This, combined with the unique constraint, makes sure we only ever have a single row
+    CONSTRAINT queue_leader_active CHECK (active IS TRUE)
+);
diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs
index aa7fadd55..e16303278 100644
--- a/crates/storage-pg/src/lib.rs
+++ b/crates/storage-pg/src/lib.rs
@@ -166,6 +166,7 @@ pub mod app_session;
 pub mod compat;
 pub mod job;
 pub mod oauth2;
+pub mod queue;
 pub mod upstream_oauth2;
 pub mod user;
 
diff --git a/crates/storage-pg/src/queue/mod.rs b/crates/storage-pg/src/queue/mod.rs
new file mode 100644
index 000000000..b6ba8295e
--- /dev/null
+++ b/crates/storage-pg/src/queue/mod.rs
@@ -0,0 +1,8 @@
+// Copyright 2024 New Vector Ltd.
+//
+// SPDX-License-Identifier: AGPL-3.0-only
+// Please see LICENSE in the repository root for full details.
+
+//! A module containing the PostgreSQL implementation of the job queue
+
+pub mod worker;
diff --git a/crates/storage-pg/src/queue/worker.rs b/crates/storage-pg/src/queue/worker.rs
new file mode 100644
index 000000000..2aaacc64b
--- /dev/null
+++ b/crates/storage-pg/src/queue/worker.rs
@@ -0,0 +1,237 @@
+// Copyright 2024 New Vector Ltd.
+//
+// SPDX-License-Identifier: AGPL-3.0-only
+// Please see LICENSE in the repository root for full details.
+
+//! A module containing the PostgreSQL implementation of the
+//! [`QueueWorkerRepository`].
+
+use async_trait::async_trait;
+use chrono::Duration;
+use mas_storage::{
+    queue::{QueueWorkerRepository, Worker},
+    Clock,
+};
+use rand::RngCore;
+use sqlx::PgConnection;
+use ulid::Ulid;
+use uuid::Uuid;
+
+use crate::{DatabaseError, ExecuteExt};
+
+/// An implementation of [`QueueWorkerRepository`] for a PostgreSQL connection.
+pub struct PgQueueWorkerRepository<'c> {
+    conn: &'c mut PgConnection,
+}
+
+impl<'c> PgQueueWorkerRepository<'c> {
+    /// Create a new [`PgQueueWorkerRepository`] from an active PostgreSQL
+    /// connection.
+ #[must_use] + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[async_trait] +impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.queue_worker.register", + skip_all, + fields( + worker.id, + db.query.text, + ), + err, + )] + async fn register( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + ) -> Result { + let now = clock.now(); + let worker_id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("worker.id", tracing::field::display(worker_id)); + + sqlx::query!( + r#" + INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at) + VALUES ($1, $2, $2) + "#, + Uuid::from(worker_id), + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Worker { id: worker_id }) + } + + #[tracing::instrument( + name = "db.queue_worker.heartbeat", + skip_all, + fields( + %worker.id, + db.query.text, + ), + err, + )] + async fn heartbeat( + &mut self, + clock: &dyn Clock, + worker: Worker, + ) -> Result { + let now = clock.now(); + let res = sqlx::query!( + r#" + UPDATE queue_workers + SET last_seen_at = $2 + WHERE queue_worker_id = $1 AND shutdown_at IS NULL + "#, + Uuid::from(worker.id), + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + // If no row was updated, the worker was shutdown so we return an error + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(worker) + } + + #[tracing::instrument( + name = "db.queue_worker.shutdown", + skip_all, + fields( + %worker.id, + db.query.text, + ), + err, + )] + async fn shutdown(&mut self, clock: &dyn Clock, worker: Worker) -> Result<(), Self::Error> { + let now = clock.now(); + let res = sqlx::query!( + r#" + UPDATE queue_workers + SET shutdown_at = $2 + WHERE queue_worker_id = $1 + "#, + Uuid::from(worker.id), + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.queue_worker.shutdown_dead_workers", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn shutdown_dead_workers( + &mut self, + clock: &dyn Clock, + threshold: Duration, + ) -> Result<(), Self::Error> { + let now = clock.now(); + sqlx::query!( + r#" + UPDATE queue_workers + SET shutdown_at = $1 + WHERE shutdown_at IS NULL + AND last_seen_at < $2 + "#, + now, + now - threshold, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.queue_worker.remove_leader_lease_if_expired", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn remove_leader_lease_if_expired( + &mut self, + clock: &dyn Clock, + ) -> Result<(), Self::Error> { + let now = clock.now(); + sqlx::query!( + r#" + DELETE FROM queue_leader + WHERE expires_at < $1 + "#, + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.queue_worker.try_get_leader_lease", + skip_all, + fields( + %worker.id, + db.query.text, + ), + err, + )] + async fn try_get_leader_lease( + &mut self, + clock: &dyn Clock, + worker: &Worker, + ) -> Result { + let now = clock.now(); + let ttl = Duration::seconds(5); + // The queue_leader table is meant to only have a single row, which conflicts on + // the `active` column + + // If there is a conflict, we update the `expires_at` column ONLY IF the current + // leader is ourselves. 
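+        //
+        // Either way, `rows_affected` is 1 when we inserted the row (becoming the
+        // leader) or renewed our own lease, and 0 when the upsert hit the conflict
+        // but the WHERE clause filtered out a row held by another worker.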
+ let res = sqlx::query!( + r#" + INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id) + VALUES ($1, $2, $3) + ON CONFLICT (active) + DO UPDATE SET expires_at = EXCLUDED.expires_at + WHERE queue_leader.queue_worker_id = $3 + "#, + now, + now + ttl, + Uuid::from(worker.id) + ) + .traced() + .execute(&mut *self.conn) + .await?; + + // We can then detect whether we are the leader or not by checking how many rows + // were affected by the upsert + let am_i_the_leader = res.rows_affected() == 1; + + Ok(am_i_the_leader) + } +} diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 284f5e2dc..99580467c 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -40,6 +40,7 @@ use crate::{ PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, + queue::worker::PgQueueWorkerRepository, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -263,4 +264,10 @@ where fn job<'c>(&'c mut self) -> Box + 'c> { Box::new(PgJobRepository::new(self.conn.as_mut())) } + + fn queue_worker<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgQueueWorkerRepository::new(self.conn.as_mut())) + } } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f2f699ba6..30dc553de 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -120,6 +120,7 @@ pub mod app_session; pub mod compat; pub mod job; pub mod oauth2; +pub mod queue; pub mod upstream_oauth2; pub mod user; diff --git a/crates/storage/src/queue/mod.rs b/crates/storage/src/queue/mod.rs new file mode 100644 index 000000000..4ca97ec5e --- /dev/null +++ b/crates/storage/src/queue/mod.rs @@ -0,0 +1,10 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! A module containing repositories for the job queue + +mod worker; + +pub use self::worker::{QueueWorkerRepository, Worker}; diff --git a/crates/storage/src/queue/worker.rs b/crates/storage/src/queue/worker.rs new file mode 100644 index 000000000..dfb9699e6 --- /dev/null +++ b/crates/storage/src/queue/worker.rs @@ -0,0 +1,130 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! Repository to interact with workers in the job queue + +use async_trait::async_trait; +use chrono::Duration; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{repository_impl, Clock}; + +/// A worker is an entity which can execute jobs. +pub struct Worker { + /// The ID of the worker. + pub id: Ulid, +} + +/// A [`QueueWorkerRepository`] is used to schedule jobs to be executed by a +/// worker. +#[async_trait] +pub trait QueueWorkerRepository: Send + Sync { + /// The error type returned by the repository. + type Error; + + /// Register a new worker. + /// + /// Returns a reference to the worker. + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn register( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + ) -> Result; + + /// Send a heartbeat for the given worker. + /// + /// Returns the updated worker. + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails or if the worker was + /// shutdown. 
+ async fn heartbeat(&mut self, clock: &dyn Clock, worker: Worker) + -> Result; + + /// Mark the given worker as shutdown. + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn shutdown(&mut self, clock: &dyn Clock, worker: Worker) -> Result<(), Self::Error>; + + /// Find dead workers and shut them down. + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn shutdown_dead_workers( + &mut self, + clock: &dyn Clock, + threshold: Duration, + ) -> Result<(), Self::Error>; + + /// Remove the leader lease if it is expired, sending a notification to + /// trigger a new leader election. + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn remove_leader_lease_if_expired( + &mut self, + clock: &dyn Clock, + ) -> Result<(), Self::Error>; + + /// Try to get the leader lease, renewing it if we already have it + /// + /// Returns `true` if we got the leader lease, `false` if we didn't + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn try_get_leader_lease( + &mut self, + clock: &dyn Clock, + worker: &Worker, + ) -> Result; +} + +repository_impl!(QueueWorkerRepository: + async fn register( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + ) -> Result; + + async fn heartbeat( + &mut self, + clock: &dyn Clock, + worker: Worker, + ) -> Result; + + async fn shutdown( + &mut self, + clock: &dyn Clock, + worker: Worker, + ) -> Result<(), Self::Error>; + + async fn shutdown_dead_workers( + &mut self, + clock: &dyn Clock, + threshold: Duration, + ) -> Result<(), Self::Error>; + + async fn remove_leader_lease_if_expired( + &mut self, + clock: &dyn Clock, + ) -> Result<(), Self::Error>; + + async fn try_get_leader_lease( + &mut self, + clock: &dyn Clock, + worker: &Worker, + ) -> Result; +); diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index a78d51d1d..55d19d281 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -18,6 +18,7 @@ use crate::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + queue::QueueWorkerRepository, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, @@ -191,6 +192,9 @@ pub trait RepositoryAccess: Send { /// Get a [`JobRepository`] fn job<'c>(&'c mut self) -> Box + 'c>; + + /// Get a [`QueueWorkerRepository`] + fn queue_worker<'c>(&'c mut self) -> Box + 'c>; } /// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and @@ -211,6 +215,7 @@ mod impls { OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + queue::QueueWorkerRepository, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, @@ -405,6 +410,12 @@ mod impls { fn job<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.job(), &mut self.mapper)) } + + fn queue_worker<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.queue_worker(), &mut self.mapper)) + } } impl RepositoryAccess for Box { @@ -527,5 +538,11 @@ mod impls { fn job<'c>(&'c mut self) -> Box + 'c> { (**self).job() } + + fn queue_worker<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).queue_worker() + } } } diff --git a/crates/tasks/src/lib.rs 
b/crates/tasks/src/lib.rs index 8d012bdae..52e9683cc 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -10,8 +10,8 @@ use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monito use mas_email::Mailer; use mas_matrix::HomeserverConnection; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRepository, SystemClock}; -use mas_storage_pg::{DatabaseError, PgRepository}; +use mas_storage::{BoxClock, BoxRepository, RepositoryError, SystemClock}; +use mas_storage_pg::PgRepository; use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tracing::debug; @@ -21,6 +21,7 @@ use crate::storage::PostgresStorageFactory; mod database; mod email; mod matrix; +mod new_queue; mod recovery; mod storage; mod user; @@ -74,8 +75,11 @@ impl State { rand_chacha::ChaChaRng::from_rng(rand::thread_rng()).expect("failed to seed rng") } - pub async fn repository(&self) -> Result { - let repo = PgRepository::from_pool(self.pool()).await?.boxed(); + pub async fn repository(&self) -> Result { + let repo = PgRepository::from_pool(self.pool()) + .await + .map_err(RepositoryError::from_error)? + .boxed(); Ok(repo) } @@ -156,5 +160,17 @@ pub async fn init( // TODO: we might want to grab the join handle here factory.listen().await?; debug!(?monitor, "workers registered"); + + // TODO: this is just spawning the task in the background, we probably actually + // want to wrap that in a structure, and handle graceful shutdown correctly + tokio::spawn(async move { + if let Err(e) = self::new_queue::run(state).await { + tracing::error!( + error = &e as &dyn std::error::Error, + "Failed to run new queue" + ); + } + }); + Ok(monitor) } diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs new file mode 100644 index 000000000..eabf17aa6 --- /dev/null +++ b/crates/tasks/src/new_queue.rs @@ -0,0 +1,81 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. 
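+
+//! The worker side of the new job queue: it registers itself in the database,
+//! sends regular heartbeats, and takes part in a lease-based leader election.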
+ +use chrono::Duration; +use mas_storage::{RepositoryAccess, RepositoryError}; + +use crate::State; + +pub async fn run(state: State) -> Result<(), RepositoryError> { + let span = tracing::info_span!("worker.init", worker.id = tracing::field::Empty); + let guard = span.enter(); + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let clock = state.clock(); + + let mut worker = repo.queue_worker().register(&mut rng, &clock).await?; + span.record("worker.id", tracing::field::display(worker.id)); + repo.save().await?; + + tracing::info!("Registered worker"); + drop(guard); + + let mut was_i_the_leader = false; + + // Record when we last sent a heartbeat + let mut last_heartbeat = clock.now(); + + loop { + // This is to make sure we wake up every second to do the maintenance tasks + // Later we might wait on other events, like a PG notification + let wakeup_sleep = tokio::time::sleep(std::time::Duration::from_secs(1)); + wakeup_sleep.await; + + let span = tracing::info_span!("worker.tick", %worker.id); + let _guard = span.enter(); + + tracing::debug!("Tick"); + let now = clock.now(); + let mut repo = state.repository().await?; + + // We send a heartbeat every minute, to avoid writing to the database too often + // on a logged table + if now - last_heartbeat >= chrono::Duration::minutes(1) { + tracing::info!("Sending heartbeat"); + worker = repo.queue_worker().heartbeat(&clock, worker).await?; + last_heartbeat = now; + } + + // Remove any dead worker leader leases + repo.queue_worker() + .remove_leader_lease_if_expired(&clock) + .await?; + + // Try to become (or stay) the leader + let am_i_the_leader = repo + .queue_worker() + .try_get_leader_lease(&clock, &worker) + .await?; + + // Log any changes in leadership + if !was_i_the_leader && am_i_the_leader { + tracing::info!("I'm the leader now"); + } else if was_i_the_leader && !am_i_the_leader { + tracing::warn!("I am no longer the leader"); + } + was_i_the_leader = am_i_the_leader; + + // The leader does all the maintenance work + if am_i_the_leader { + // We also check if the worker is dead, and if so, we shutdown all the dead + // workers that haven't checked in the last two minutes + repo.queue_worker() + .shutdown_dead_workers(&clock, Duration::minutes(2)) + .await?; + } + + repo.save().await?; + } +} From 63632645196562671514b5a07537f2330f65fa6b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 9 Oct 2024 10:28:59 +0200 Subject: [PATCH 02/17] Make the worker heartbeat take a worker reference --- crates/storage-pg/src/queue/worker.rs | 8 ++------ crates/storage/src/queue/worker.rs | 9 +++------ crates/tasks/src/new_queue.rs | 4 ++-- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/crates/storage-pg/src/queue/worker.rs b/crates/storage-pg/src/queue/worker.rs index 2aaacc64b..5fa784cb1 100644 --- a/crates/storage-pg/src/queue/worker.rs +++ b/crates/storage-pg/src/queue/worker.rs @@ -79,11 +79,7 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { ), err, )] - async fn heartbeat( - &mut self, - clock: &dyn Clock, - worker: Worker, - ) -> Result { + async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> { let now = clock.now(); let res = sqlx::query!( r#" @@ -101,7 +97,7 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { // If no row was updated, the worker was shutdown so we return an error DatabaseError::ensure_affected_rows(&res, 1)?; - Ok(worker) + Ok(()) } #[tracing::instrument( diff --git a/crates/storage/src/queue/worker.rs 
b/crates/storage/src/queue/worker.rs index dfb9699e6..19ceead88 100644 --- a/crates/storage/src/queue/worker.rs +++ b/crates/storage/src/queue/worker.rs @@ -40,14 +40,11 @@ pub trait QueueWorkerRepository: Send + Sync { /// Send a heartbeat for the given worker. /// - /// Returns the updated worker. - /// /// # Errors /// /// Returns an error if the underlying repository fails or if the worker was /// shutdown. - async fn heartbeat(&mut self, clock: &dyn Clock, worker: Worker) - -> Result; + async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error>; /// Mark the given worker as shutdown. /// @@ -102,8 +99,8 @@ repository_impl!(QueueWorkerRepository: async fn heartbeat( &mut self, clock: &dyn Clock, - worker: Worker, - ) -> Result; + worker: &Worker, + ) -> Result<(), Self::Error>; async fn shutdown( &mut self, diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index eabf17aa6..4a8058b59 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -15,7 +15,7 @@ pub async fn run(state: State) -> Result<(), RepositoryError> { let mut rng = state.rng(); let clock = state.clock(); - let mut worker = repo.queue_worker().register(&mut rng, &clock).await?; + let worker = repo.queue_worker().register(&mut rng, &clock).await?; span.record("worker.id", tracing::field::display(worker.id)); repo.save().await?; @@ -44,7 +44,7 @@ pub async fn run(state: State) -> Result<(), RepositoryError> { // on a logged table if now - last_heartbeat >= chrono::Duration::minutes(1) { tracing::info!("Sending heartbeat"); - worker = repo.queue_worker().heartbeat(&clock, worker).await?; + repo.queue_worker().heartbeat(&clock, &worker).await?; last_heartbeat = now; } From 8c1a87b0df513071fbacaadfd66af24dd95dcd1f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 14 Oct 2024 10:51:14 +0200 Subject: [PATCH 03/17] Move the worker logic in a struct --- crates/tasks/src/lib.rs | 13 +- crates/tasks/src/new_queue.rs | 218 +++++++++++++++++++++++++++------- 2 files changed, 183 insertions(+), 48 deletions(-) diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 52e9683cc..0db6b6b81 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -12,6 +12,7 @@ use mas_matrix::HomeserverConnection; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, RepositoryError, SystemClock}; use mas_storage_pg::PgRepository; +use new_queue::QueueRunnerError; use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tracing::debug; @@ -142,7 +143,7 @@ pub async fn init( mailer: &Mailer, homeserver: impl HomeserverConnection + 'static, url_builder: UrlBuilder, -) -> Result, sqlx::Error> { +) -> Result, QueueRunnerError> { let state = State::new( pool.clone(), SystemClock::default(), @@ -158,13 +159,19 @@ pub async fn init( let monitor = self::user::register(name, monitor, &state, &factory); let monitor = self::recovery::register(name, monitor, &state, &factory); // TODO: we might want to grab the join handle here - factory.listen().await?; + // TODO: this error isn't right, I just want that to compile + factory + .listen() + .await + .map_err(QueueRunnerError::SetupListener)?; debug!(?monitor, "workers registered"); + let mut worker = self::new_queue::QueueWorker::new(state).await?; + // TODO: this is just spawning the task in the background, we probably actually // want to wrap that in a structure, and handle graceful shutdown correctly tokio::spawn(async move { - if let Err(e) = self::new_queue::run(state).await { + if 
let Err(e) = worker.run().await { tracing::error!( error = &e as &dyn std::error::Error, "Failed to run new queue" diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index 4a8058b59..c3596c0e9 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -3,79 +3,207 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use chrono::Duration; -use mas_storage::{RepositoryAccess, RepositoryError}; +use chrono::{DateTime, Duration, Utc}; +use mas_storage::{queue::Worker, Clock, RepositoryAccess, RepositoryError}; +use mas_storage_pg::{DatabaseError, PgRepository}; +use rand::{distributions::Uniform, Rng}; +use rand_chacha::ChaChaRng; +use sqlx::PgPool; +use thiserror::Error; use crate::State; -pub async fn run(state: State) -> Result<(), RepositoryError> { - let span = tracing::info_span!("worker.init", worker.id = tracing::field::Empty); - let guard = span.enter(); - let mut repo = state.repository().await?; - let mut rng = state.rng(); - let clock = state.clock(); +#[derive(Debug, Error)] +pub enum QueueRunnerError { + #[error("Failed to setup listener")] + SetupListener(#[source] sqlx::Error), - let worker = repo.queue_worker().register(&mut rng, &clock).await?; - span.record("worker.id", tracing::field::display(worker.id)); - repo.save().await?; + #[error("Failed to start transaction")] + StartTransaction(#[source] sqlx::Error), - tracing::info!("Registered worker"); - drop(guard); + #[error("Failed to commit transaction")] + CommitTransaction(#[source] sqlx::Error), - let mut was_i_the_leader = false; + #[error(transparent)] + Repository(#[from] RepositoryError), - // Record when we last sent a heartbeat - let mut last_heartbeat = clock.now(); + #[error(transparent)] + Database(#[from] DatabaseError), - loop { + #[error("Worker is not the leader")] + NotLeader, +} + +const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900); +const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100); + +pub struct QueueWorker { + rng: ChaChaRng, + clock: Box, + pool: PgPool, + registration: Worker, + am_i_leader: bool, + last_heartbeat: DateTime, +} + +impl QueueWorker { + #[tracing::instrument( + name = "worker.init", + skip_all, + fields(worker.id) + )] + pub async fn new(state: State) -> Result { + let mut rng = state.rng(); + let clock = state.clock(); + let pool = state.pool().clone(); + + let txn = pool + .begin() + .await + .map_err(QueueRunnerError::StartTransaction)?; + let mut repo = PgRepository::from_conn(txn); + + let registration = repo.queue_worker().register(&mut rng, &clock).await?; + tracing::Span::current().record("worker.id", tracing::field::display(registration.id)); + repo.into_inner() + .commit() + .await + .map_err(QueueRunnerError::CommitTransaction)?; + + tracing::info!("Registered worker"); + let now = clock.now(); + + Ok(Self { + rng, + clock, + pool, + registration, + am_i_leader: false, + last_heartbeat: now, + }) + } + + pub async fn run(&mut self) -> Result<(), QueueRunnerError> { + loop { + self.run_loop().await?; + } + } + + #[tracing::instrument(name = "worker.run_loop", skip_all, err)] + async fn run_loop(&mut self) -> Result<(), QueueRunnerError> { + self.wait_until_wakeup().await?; + self.tick().await?; + + if self.am_i_leader { + self.perform_leader_duties().await?; + } + + Ok(()) + } + + #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)] + async fn wait_until_wakeup(&mut self) -> Result<(), 
QueueRunnerError> { // This is to make sure we wake up every second to do the maintenance tasks - // Later we might wait on other events, like a PG notification - let wakeup_sleep = tokio::time::sleep(std::time::Duration::from_secs(1)); - wakeup_sleep.await; + // We add a little bit of random jitter to the duration, so that we don't get + // fully synced workers waking up at the same time after each notification + let sleep_duration = self + .rng + .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION)); + tokio::time::sleep(sleep_duration).await; + tracing::debug!("Woke up from sleep"); + + Ok(()) + } - let span = tracing::info_span!("worker.tick", %worker.id); - let _guard = span.enter(); + fn set_new_leader_state(&mut self, state: bool) { + // Do nothing if we were already on that state + if state == self.am_i_leader { + return; + } + // If we flipped state, log it + self.am_i_leader = state; + if self.am_i_leader { + tracing::info!("I'm the leader now"); + } else { + tracing::warn!("I am no longer the leader"); + } + } + + #[tracing::instrument( + name = "worker.tick", + skip_all, + fields(worker.id = %self.registration.id), + err, + )] + async fn tick(&mut self) -> Result<(), QueueRunnerError> { tracing::debug!("Tick"); - let now = clock.now(); - let mut repo = state.repository().await?; + let now = self.clock.now(); + + let txn = self + .pool + .begin() + .await + .map_err(QueueRunnerError::StartTransaction)?; + let mut repo = PgRepository::from_conn(txn); // We send a heartbeat every minute, to avoid writing to the database too often // on a logged table - if now - last_heartbeat >= chrono::Duration::minutes(1) { + if now - self.last_heartbeat >= chrono::Duration::minutes(1) { tracing::info!("Sending heartbeat"); - repo.queue_worker().heartbeat(&clock, &worker).await?; - last_heartbeat = now; + repo.queue_worker() + .heartbeat(&self.clock, &self.registration) + .await?; + self.last_heartbeat = now; } // Remove any dead worker leader leases repo.queue_worker() - .remove_leader_lease_if_expired(&clock) + .remove_leader_lease_if_expired(&self.clock) .await?; // Try to become (or stay) the leader - let am_i_the_leader = repo + let leader = repo .queue_worker() - .try_get_leader_lease(&clock, &worker) + .try_get_leader_lease(&self.clock, &self.registration) .await?; - // Log any changes in leadership - if !was_i_the_leader && am_i_the_leader { - tracing::info!("I'm the leader now"); - } else if was_i_the_leader && !am_i_the_leader { - tracing::warn!("I am no longer the leader"); - } - was_i_the_leader = am_i_the_leader; + repo.into_inner() + .commit() + .await + .map_err(QueueRunnerError::CommitTransaction)?; - // The leader does all the maintenance work - if am_i_the_leader { - // We also check if the worker is dead, and if so, we shutdown all the dead - // workers that haven't checked in the last two minutes - repo.queue_worker() - .shutdown_dead_workers(&clock, Duration::minutes(2)) - .await?; + // Save the new leader state + self.set_new_leader_state(leader); + + Ok(()) + } + + #[tracing::instrument(name = "worker.perform_leader_duties", skip_all, err)] + async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> { + // This should have been checked by the caller, but better safe than sorry + if !self.am_i_leader { + return Err(QueueRunnerError::NotLeader); } - repo.save().await?; + let txn = self + .pool + .begin() + .await + .map_err(QueueRunnerError::StartTransaction)?; + let mut repo = PgRepository::from_conn(txn); + + // We also check if the worker is dead, 
and if so, we shutdown all the dead + // workers that haven't checked in the last two minutes + repo.queue_worker() + .shutdown_dead_workers(&self.clock, Duration::minutes(2)) + .await?; + + repo.into_inner() + .commit() + .await + .map_err(QueueRunnerError::CommitTransaction)?; + + Ok(()) } } From d3e51a126e02709cf05f8c426fc15a068ef5071d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 8 Oct 2024 17:02:18 +0200 Subject: [PATCH 04/17] TEMP: use patched sqlx We would like to use the underlying connection from the PgListener, which was added in a patch, but not yet merged or released. --- Cargo.lock | 51 ++++++++++----------------------------------------- Cargo.toml | 7 +++++++ deny.toml | 2 +- 3 files changed, 18 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0c7db2a09..a61e07919 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3084,7 +3084,6 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ - "cc", "pkg-config", "vcpkg", ] @@ -5882,21 +5881,10 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a" -[[package]] -name = "sqlformat" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" -dependencies = [ - "nom", - "unicode_categories", -] - [[package]] name = "sqlx" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93334716a037193fac19df402f8571269c84a00852f6a7066b5d2616dcd64d3e" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ "sqlx-core", "sqlx-macros", @@ -5908,31 +5896,25 @@ dependencies = [ [[package]] name = "sqlx-core" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ - "atoi", - "byteorder", "bytes", "chrono", "crc", "crossbeam-queue", "either", "event-listener 5.3.1", - "futures-channel", "futures-core", "futures-intrusive", "futures-io", "futures-util", "hashbrown 0.14.5", "hashlink", - "hex", "indexmap 2.6.0", "ipnetwork", "log", "memchr", "once_cell", - "paste", "percent-encoding", "rustls", "rustls-pemfile", @@ -5940,8 +5922,7 @@ dependencies = [ "serde_json", "sha2", "smallvec", - "sqlformat", - "thiserror 1.0.69", + "thiserror 2.0.3", "tokio", "tokio-stream", "tracing", @@ -5953,8 +5934,7 @@ dependencies = [ [[package]] name = "sqlx-macros" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac0692bcc9de3b073e8d747391827297e075c7710ff6276d9f7a1f3d58c6657" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ "proc-macro2", "quote", @@ -5966,8 +5946,7 @@ dependencies = [ [[package]] name = "sqlx-macros-core" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ "dotenvy", "either", @@ -5992,8 +5971,7 
@@ dependencies = [ [[package]] name = "sqlx-mysql" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ "atoi", "base64 0.22.1", @@ -6027,7 +6005,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 1.0.69", + "thiserror 2.0.3", "tracing", "uuid", "whoami", @@ -6036,8 +6014,7 @@ dependencies = [ [[package]] name = "sqlx-postgres" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ "atoi", "base64 0.22.1", @@ -6049,7 +6026,6 @@ dependencies = [ "etcetera", "futures-channel", "futures-core", - "futures-io", "futures-util", "hex", "hkdf", @@ -6068,7 +6044,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 1.0.69", + "thiserror 2.0.3", "tracing", "uuid", "whoami", @@ -6077,8 +6053,7 @@ dependencies = [ [[package]] name = "sqlx-sqlite" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" +source = "git+https://github.com/launchbadge/sqlx.git?branch=main#42ce24dab87aad98f041cafb35cf9a7d5b2b09a7" dependencies = [ "atoi", "chrono", @@ -6778,12 +6753,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" -[[package]] -name = "unicode_categories" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" - [[package]] name = "universal-hash" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 9d2927037..8d14cafde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -373,3 +373,10 @@ rayon.opt-level = 3 regalloc2.opt-level = 3 sha2.opt-level = 3 sqlx-macros.opt-level = 3 + +[patch.crates-io] +sqlx = { git = "https://github.com/launchbadge/sqlx.git", branch = "main" } +sqlx-core = { git = "https://github.com/launchbadge/sqlx.git", branch = "main" } +sqlx-macros = { git = "https://github.com/launchbadge/sqlx.git", branch = "main" } +sqlx-macros-core = { git = "https://github.com/launchbadge/sqlx.git", branch = "main" } +sqlx-postgres = { git = "https://github.com/launchbadge/sqlx.git", branch = "main" } diff --git a/deny.toml b/deny.toml index f1dde8f9a..571923c35 100644 --- a/deny.toml +++ b/deny.toml @@ -89,4 +89,4 @@ deny = ["oldtime"] unknown-registry = "warn" unknown-git = "warn" allow-registry = ["https://github.com/rust-lang/crates.io-index"] -allow-git = [] +allow-git = ["https://github.com/launchbadge/sqlx.git"] From f9c8ade312c9d46f950964ec1e64c22419b45e99 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 14 Oct 2024 10:54:55 +0200 Subject: [PATCH 05/17] Graceful shutdown --- Cargo.lock | 1 + crates/cli/src/commands/server.rs | 2 + crates/cli/src/commands/worker.rs | 39 +++- ...ab747a469404533f59ff6fbd56e9eb5ad38e1.json | 14 ++ ...dabe674ea853e0d47eb5c713705cb0130c758.json | 12 ++ crates/storage-pg/src/queue/worker.rs | 26 ++- crates/storage/src/queue/worker.rs | 4 +- crates/tasks/Cargo.toml | 7 +- crates/tasks/src/lib.rs | 9 +- crates/tasks/src/new_queue.rs 
| 180 +++++++++++++++--- 10 files changed, 250 insertions(+), 44 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-399e261027fe6c9167511636157ab747a469404533f59ff6fbd56e9eb5ad38e1.json create mode 100644 crates/storage-pg/.sqlx/query-6ecad60e565367a6cfa539b4c32dabe674ea853e0d47eb5c713705cb0130c758.json diff --git a/Cargo.lock b/Cargo.lock index a61e07919..d8d66a228 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3758,6 +3758,7 @@ dependencies = [ "sqlx", "thiserror 2.0.3", "tokio", + "tokio-util", "tower 0.5.1", "tracing", "tracing-opentelemetry", diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 0558453d8..70834ccba 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -172,6 +172,8 @@ impl Options { &mailer, homeserver_connection.clone(), url_builder.clone(), + shutdown.soft_shutdown_token(), + shutdown.task_tracker(), ) .await?; diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index 3bbef12dd..c58605a1b 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -17,8 +17,12 @@ use rand::{ }; use tracing::{info, info_span}; -use crate::util::{ - database_pool_from_config, mailer_from_config, site_config_from_config, templates_from_config, +use crate::{ + shutdown::ShutdownManager, + util::{ + database_pool_from_config, mailer_from_config, site_config_from_config, + templates_from_config, + }, }; #[derive(Parser, Debug, Default)] @@ -26,6 +30,7 @@ pub(super) struct Options {} impl Options { pub async fn run(self, figment: &Figment) -> anyhow::Result { + let shutdown = ShutdownManager::new()?; let span = info_span!("cli.worker.init").entered(); let config = AppConfig::extract(figment)?; @@ -71,11 +76,35 @@ impl Options { let worker_name = Alphanumeric.sample_string(&mut rng, 10); info!(worker_name, "Starting task scheduler"); - let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn, url_builder).await?; - + let monitor = mas_tasks::init( + &worker_name, + &pool, + &mailer, + conn, + url_builder, + shutdown.soft_shutdown_token(), + shutdown.task_tracker(), + ) + .await?; + + // XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns, + // ideally we'd just give it a cancellation token + let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned(); + shutdown.task_tracker().spawn(async move { + if let Err(e) = monitor + .run_with_signal(async move { + shutdown_future.await; + Ok(()) + }) + .await + { + tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed"); + } + }); span.exit(); - monitor.run().await?; + shutdown.run().await; + Ok(ExitCode::SUCCESS) } } diff --git a/crates/storage-pg/.sqlx/query-399e261027fe6c9167511636157ab747a469404533f59ff6fbd56e9eb5ad38e1.json b/crates/storage-pg/.sqlx/query-399e261027fe6c9167511636157ab747a469404533f59ff6fbd56e9eb5ad38e1.json new file mode 100644 index 000000000..f0a50a645 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-399e261027fe6c9167511636157ab747a469404533f59ff6fbd56e9eb5ad38e1.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM queue_leader\n WHERE queue_worker_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "399e261027fe6c9167511636157ab747a469404533f59ff6fbd56e9eb5ad38e1" +} diff --git a/crates/storage-pg/.sqlx/query-6ecad60e565367a6cfa539b4c32dabe674ea853e0d47eb5c713705cb0130c758.json 
b/crates/storage-pg/.sqlx/query-6ecad60e565367a6cfa539b4c32dabe674ea853e0d47eb5c713705cb0130c758.json new file mode 100644 index 000000000..564800a24 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-6ecad60e565367a6cfa539b4c32dabe674ea853e0d47eb5c713705cb0130c758.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "\n NOTIFY queue_leader_stepdown\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "6ecad60e565367a6cfa539b4c32dabe674ea853e0d47eb5c713705cb0130c758" +} diff --git a/crates/storage-pg/src/queue/worker.rs b/crates/storage-pg/src/queue/worker.rs index 5fa784cb1..5d3c566b2 100644 --- a/crates/storage-pg/src/queue/worker.rs +++ b/crates/storage-pg/src/queue/worker.rs @@ -109,7 +109,7 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { ), err, )] - async fn shutdown(&mut self, clock: &dyn Clock, worker: Worker) -> Result<(), Self::Error> { + async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> { let now = clock.now(); let res = sqlx::query!( r#" @@ -126,6 +126,30 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { DatabaseError::ensure_affected_rows(&res, 1)?; + // Remove the leader lease if we were holding it + let res = sqlx::query!( + r#" + DELETE FROM queue_leader + WHERE queue_worker_id = $1 + "#, + Uuid::from(worker.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + // If we were holding the leader lease, notify workers + if res.rows_affected() > 0 { + sqlx::query!( + r#" + NOTIFY queue_leader_stepdown + "#, + ) + .traced() + .execute(&mut *self.conn) + .await?; + } + Ok(()) } diff --git a/crates/storage/src/queue/worker.rs b/crates/storage/src/queue/worker.rs index 19ceead88..4916ec1fd 100644 --- a/crates/storage/src/queue/worker.rs +++ b/crates/storage/src/queue/worker.rs @@ -51,7 +51,7 @@ pub trait QueueWorkerRepository: Send + Sync { /// # Errors /// /// Returns an error if the underlying repository fails. - async fn shutdown(&mut self, clock: &dyn Clock, worker: Worker) -> Result<(), Self::Error>; + async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error>; /// Find dead workers and shut them down. 
/// @@ -105,7 +105,7 @@ repository_impl!(QueueWorkerRepository: async fn shutdown( &mut self, clock: &dyn Clock, - worker: Worker, + worker: &Worker, ) -> Result<(), Self::Error>; async fn shutdown_dead_workers( diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 8aff0da36..763b0f5fc 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -13,7 +13,11 @@ workspace = true [dependencies] anyhow.workspace = true -apalis-core = { version = "0.4.9", features = ["extensions", "tokio-comp", "storage"] } +apalis-core = { version = "0.4.9", features = [ + "extensions", + "tokio-comp", + "storage", +] } apalis-cron = "0.4.9" async-stream = "0.3.6" async-trait.workspace = true @@ -25,6 +29,7 @@ rand_chacha = "0.3.1" sqlx.workspace = true thiserror.workspace = true tokio.workspace = true +tokio-util.workspace = true tower.workspace = true tracing.workspace = true tracing-opentelemetry.workspace = true diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 0db6b6b81..ac85eba14 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -15,6 +15,7 @@ use mas_storage_pg::PgRepository; use new_queue::QueueRunnerError; use rand::SeedableRng; use sqlx::{Pool, Postgres}; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::debug; use crate::storage::PostgresStorageFactory; @@ -143,6 +144,8 @@ pub async fn init( mailer: &Mailer, homeserver: impl HomeserverConnection + 'static, url_builder: UrlBuilder, + cancellation_token: CancellationToken, + task_tracker: &TaskTracker, ) -> Result, QueueRunnerError> { let state = State::new( pool.clone(), @@ -166,11 +169,9 @@ pub async fn init( .map_err(QueueRunnerError::SetupListener)?; debug!(?monitor, "workers registered"); - let mut worker = self::new_queue::QueueWorker::new(state).await?; + let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?; - // TODO: this is just spawning the task in the background, we probably actually - // want to wrap that in a structure, and handle graceful shutdown correctly - tokio::spawn(async move { + task_tracker.spawn(async move { if let Err(e) = worker.run().await { tracing::error!( error = &e as &dyn std::error::Error, diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index c3596c0e9..571f8591b 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -8,8 +8,12 @@ use mas_storage::{queue::Worker, Clock, RepositoryAccess, RepositoryError}; use mas_storage_pg::{DatabaseError, PgRepository}; use rand::{distributions::Uniform, Rng}; use rand_chacha::ChaChaRng; -use sqlx::PgPool; +use sqlx::{ + postgres::{PgAdvisoryLock, PgListener}, + Acquire, Either, +}; use thiserror::Error; +use tokio_util::sync::CancellationToken; use crate::State; @@ -24,6 +28,9 @@ pub enum QueueRunnerError { #[error("Failed to commit transaction")] CommitTransaction(#[source] sqlx::Error), + #[error("Failed to acquire leader lock")] + LeaderLock(#[source] sqlx::Error), + #[error(transparent)] Repository(#[from] RepositoryError), @@ -34,16 +41,21 @@ pub enum QueueRunnerError { NotLeader, } +// When the worker waits for a notification, we still want to wake it up every +// second. Because we don't want all the workers to wake up at the same time, we +// add a random jitter to the sleep duration, so they effectively sleep between +// 0.9 and 1.1 seconds. 
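+//
+// For example, with a uniform sample over [900ms, 1100ms), two workers that
+// woke up at the same instant can drift up to 200ms apart on the next tick.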
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900); const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100); pub struct QueueWorker { rng: ChaChaRng, clock: Box, - pool: PgPool, + listener: PgListener, registration: Worker, am_i_leader: bool, last_heartbeat: DateTime, + cancellation_token: CancellationToken, } impl QueueWorker { @@ -52,12 +64,24 @@ impl QueueWorker { skip_all, fields(worker.id) )] - pub async fn new(state: State) -> Result { + pub async fn new( + state: State, + cancellation_token: CancellationToken, + ) -> Result { let mut rng = state.rng(); let clock = state.clock(); - let pool = state.pool().clone(); - let txn = pool + let mut listener = PgListener::connect_with(state.pool()) + .await + .map_err(QueueRunnerError::SetupListener)?; + + // We get notifications of leader stepping down on this channel + listener + .listen("queue_leader_stepdown") + .await + .map_err(QueueRunnerError::SetupListener)?; + + let txn = listener .begin() .await .map_err(QueueRunnerError::StartTransaction)?; @@ -76,22 +100,32 @@ impl QueueWorker { Ok(Self { rng, clock, - pool, + listener, registration, am_i_leader: false, last_heartbeat: now, + cancellation_token, }) } pub async fn run(&mut self) -> Result<(), QueueRunnerError> { - loop { + while !self.cancellation_token.is_cancelled() { self.run_loop().await?; } + + self.shutdown().await?; + + Ok(()) } #[tracing::instrument(name = "worker.run_loop", skip_all, err)] async fn run_loop(&mut self) -> Result<(), QueueRunnerError> { self.wait_until_wakeup().await?; + + if self.cancellation_token.is_cancelled() { + return Ok(()); + } + self.tick().await?; if self.am_i_leader { @@ -101,6 +135,33 @@ impl QueueWorker { Ok(()) } + #[tracing::instrument(name = "worker.shutdown", skip_all, err)] + async fn shutdown(&mut self) -> Result<(), QueueRunnerError> { + tracing::info!("Shutting down worker"); + + // Start a transaction on the existing PgListener connection + let txn = self + .listener + .begin() + .await + .map_err(QueueRunnerError::StartTransaction)?; + + let mut repo = PgRepository::from_conn(txn); + + // Tell the other workers we're shutting down + // This also releases the leader election lease + repo.queue_worker() + .shutdown(&self.clock, &self.registration) + .await?; + + repo.into_inner() + .commit() + .await + .map_err(QueueRunnerError::CommitTransaction)?; + + Ok(()) + } + #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)] async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> { // This is to make sure we wake up every second to do the maintenance tasks @@ -109,25 +170,34 @@ impl QueueWorker { let sleep_duration = self .rng .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION)); - tokio::time::sleep(sleep_duration).await; - tracing::debug!("Woke up from sleep"); - - Ok(()) - } - - fn set_new_leader_state(&mut self, state: bool) { - // Do nothing if we were already on that state - if state == self.am_i_leader { - return; + let wakeup_sleep = tokio::time::sleep(sleep_duration); + + tokio::select! 
{ + () = self.cancellation_token.cancelled() => { + tracing::debug!("Woke up from cancellation"); + }, + + () = wakeup_sleep => { + tracing::debug!("Woke up from sleep"); + }, + + notification = self.listener.recv() => { + match notification { + Ok(notification) => { + tracing::debug!( + notification.channel = notification.channel(), + notification.payload = notification.payload(), + "Woke up from notification" + ); + }, + Err(e) => { + tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification"); + }, + } + }, } - // If we flipped state, log it - self.am_i_leader = state; - if self.am_i_leader { - tracing::info!("I'm the leader now"); - } else { - tracing::warn!("I am no longer the leader"); - } + Ok(()) } #[tracing::instrument( @@ -140,8 +210,9 @@ impl QueueWorker { tracing::debug!("Tick"); let now = self.clock.now(); + // Start a transaction on the existing PgListener connection let txn = self - .pool + .listener .begin() .await .map_err(QueueRunnerError::StartTransaction)?; @@ -168,13 +239,23 @@ impl QueueWorker { .try_get_leader_lease(&self.clock, &self.registration) .await?; + // After this point, we are locking the leader table, so it's important that we + // commit as soon as possible to not block the other workers for too long repo.into_inner() .commit() .await .map_err(QueueRunnerError::CommitTransaction)?; - // Save the new leader state - self.set_new_leader_state(leader); + // Save the new leader state to log any change + if leader != self.am_i_leader { + // If we flipped state, log it + self.am_i_leader = leader; + if self.am_i_leader { + tracing::info!("I'm the leader now"); + } else { + tracing::warn!("I am no longer the leader"); + } + } Ok(()) } @@ -186,12 +267,43 @@ impl QueueWorker { return Err(QueueRunnerError::NotLeader); } + // Start a transaction on the existing PgListener connection let txn = self - .pool + .listener .begin() .await .map_err(QueueRunnerError::StartTransaction)?; - let mut repo = PgRepository::from_conn(txn); + + // The thing with the leader election is that it locks the table during the + // election, preventing other workers from going through the loop. + // + // Ideally, we would do the leader duties in the same transaction so that we + // make sure only one worker is doing the leader duties, but that + // would mean we would lock all the workers for the duration of the + // duties, which is not ideal. + // + // So we do the duties in a separate transaction, in which we take an advisory + // lock, so that in the very rare case where two workers think they are the + // leader, we still don't have two workers doing the duties at the same time. 
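+        //
+        // `PgAdvisoryLock` hashes the string key into the 64-bit integer key that
+        // Postgres advisory locks expect; `try_acquire` hands back a guard when the
+        // lock is free, or returns the connection when another session holds it.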
+ let lock = PgAdvisoryLock::new("leader-duties"); + + let locked = lock + .try_acquire(txn) + .await + .map_err(QueueRunnerError::LeaderLock)?; + + let locked = match locked { + Either::Left(locked) => locked, + Either::Right(txn) => { + tracing::error!("Another worker has the leader lock, aborting"); + txn.rollback() + .await + .map_err(QueueRunnerError::CommitTransaction)?; + return Ok(()); + } + }; + + let mut repo = PgRepository::from_conn(locked); // We also check if the worker is dead, and if so, we shutdown all the dead // workers that haven't checked in the last two minutes @@ -199,8 +311,14 @@ impl QueueWorker { .shutdown_dead_workers(&self.clock, Duration::minutes(2)) .await?; - repo.into_inner() - .commit() + // Release the leader lock + let txn = repo + .into_inner() + .release_now() + .await + .map_err(QueueRunnerError::LeaderLock)?; + + txn.commit() .await .map_err(QueueRunnerError::CommitTransaction)?; From 2692d9a28f9d8aa38090720b332f35eb08804da0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 19 Nov 2024 17:23:55 +0100 Subject: [PATCH 06/17] Use the database time for leader election --- ...4d55cbfb01492985ac2af5a1ad4af9b3ccc77.json | 15 ++++++++++++++ ...d356a4ed86fd33400066e422545ffc55f9aa9.json | 16 --------------- ...f3005b55c654897a8e46dc933c7fd2263c7c.json} | 8 +++----- crates/storage-pg/src/queue/worker.rs | 20 +++++++++++-------- 4 files changed, 30 insertions(+), 29 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-67cd4880d84b38f20c3960789934d55cbfb01492985ac2af5a1ad4af9b3ccc77.json delete mode 100644 crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json rename crates/storage-pg/.sqlx/{query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json => query-fe7bd146523e4bb321cb234d6bf9f3005b55c654897a8e46dc933c7fd2263c7c.json} (51%) diff --git a/crates/storage-pg/.sqlx/query-67cd4880d84b38f20c3960789934d55cbfb01492985ac2af5a1ad4af9b3ccc77.json b/crates/storage-pg/.sqlx/query-67cd4880d84b38f20c3960789934d55cbfb01492985ac2af5a1ad4af9b3ccc77.json new file mode 100644 index 000000000..1d739df31 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-67cd4880d84b38f20c3960789934d55cbfb01492985ac2af5a1ad4af9b3ccc77.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)\n VALUES ($1, NOW() + INTERVAL '5 seconds', $2)\n ON CONFLICT (active)\n DO UPDATE SET expires_at = EXCLUDED.expires_at\n WHERE queue_leader.queue_worker_id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "67cd4880d84b38f20c3960789934d55cbfb01492985ac2af5a1ad4af9b3ccc77" +} diff --git a/crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json b/crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json deleted file mode 100644 index 9195a9d4d..000000000 --- a/crates/storage-pg/.sqlx/query-8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)\n VALUES ($1, $2, $3)\n ON CONFLICT (active)\n DO UPDATE SET expires_at = EXCLUDED.expires_at\n WHERE queue_leader.queue_worker_id = $3\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Timestamptz", - "Timestamptz", - "Uuid" - ] - }, - "nullable": [] - }, - "hash": 
"8defee03b9ed60c2b8cc6478e34d356a4ed86fd33400066e422545ffc55f9aa9" -} diff --git a/crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json b/crates/storage-pg/.sqlx/query-fe7bd146523e4bb321cb234d6bf9f3005b55c654897a8e46dc933c7fd2263c7c.json similarity index 51% rename from crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json rename to crates/storage-pg/.sqlx/query-fe7bd146523e4bb321cb234d6bf9f3005b55c654897a8e46dc933c7fd2263c7c.json index af6213a8a..9cfabe8ed 100644 --- a/crates/storage-pg/.sqlx/query-ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5.json +++ b/crates/storage-pg/.sqlx/query-fe7bd146523e4bb321cb234d6bf9f3005b55c654897a8e46dc933c7fd2263c7c.json @@ -1,14 +1,12 @@ { "db_name": "PostgreSQL", - "query": "\n DELETE FROM queue_leader\n WHERE expires_at < $1\n ", + "query": "\n DELETE FROM queue_leader\n WHERE expires_at < NOW()\n ", "describe": { "columns": [], "parameters": { - "Left": [ - "Timestamptz" - ] + "Left": [] }, "nullable": [] }, - "hash": "ed8bcbcd4b7f93f654670cf077f84163ae08e16a8e07c6ecbca4fd8cb10da8a5" + "hash": "fe7bd146523e4bb321cb234d6bf9f3005b55c654897a8e46dc933c7fd2263c7c" } diff --git a/crates/storage-pg/src/queue/worker.rs b/crates/storage-pg/src/queue/worker.rs index 5d3c566b2..61b96b5fd 100644 --- a/crates/storage-pg/src/queue/worker.rs +++ b/crates/storage-pg/src/queue/worker.rs @@ -166,6 +166,9 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { clock: &dyn Clock, threshold: Duration, ) -> Result<(), Self::Error> { + // Here the threshold is usually set to a few minutes, so we don't need to use + // the database time, as we can assume worker clocks have less than a minute + // skew between each other, else other things would break let now = clock.now(); sqlx::query!( r#" @@ -194,15 +197,15 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { )] async fn remove_leader_lease_if_expired( &mut self, - clock: &dyn Clock, + _clock: &dyn Clock, ) -> Result<(), Self::Error> { - let now = clock.now(); + // `expires_at` is a rare exception where we use the database time, as this + // would be very sensitive to clock skew between workers sqlx::query!( r#" DELETE FROM queue_leader - WHERE expires_at < $1 + WHERE expires_at < NOW() "#, - now, ) .traced() .execute(&mut *self.conn) @@ -226,22 +229,23 @@ impl QueueWorkerRepository for PgQueueWorkerRepository<'_> { worker: &Worker, ) -> Result { let now = clock.now(); - let ttl = Duration::seconds(5); // The queue_leader table is meant to only have a single row, which conflicts on // the `active` column // If there is a conflict, we update the `expires_at` column ONLY IF the current // leader is ourselves. 
+ + // `expires_at` is a rare exception where we use the database time, as this + // would be very sensitive to clock skew between workers let res = sqlx::query!( r#" INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id) - VALUES ($1, $2, $3) + VALUES ($1, NOW() + INTERVAL '5 seconds', $2) ON CONFLICT (active) DO UPDATE SET expires_at = EXCLUDED.expires_at - WHERE queue_leader.queue_worker_id = $3 + WHERE queue_leader.queue_worker_id = $2 "#, now, - now + ttl, Uuid::from(worker.id) ) .traced() From 874ae397159cfcf321877ca043623e50b9882120 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 14 Oct 2024 10:29:56 +0200 Subject: [PATCH 07/17] WIP jobs --- ...293c35ba4503855d52a5b62b6e86b126362f5.json | 37 +++++ ...8e601f12c8003fe93a5ecb110d02642d14c3c.json | 18 +++ .../migrations/20241004121132_queue_job.sql | 79 +++++++++ crates/storage-pg/src/queue/job.rs | 150 ++++++++++++++++++ crates/storage-pg/src/queue/mod.rs | 1 + crates/storage/src/queue/job.rs | 105 ++++++++++++ crates/storage/src/queue/mod.rs | 6 +- 7 files changed, 395 insertions(+), 1 deletion(-) create mode 100644 crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json create mode 100644 crates/storage-pg/.sqlx/query-e291be0434ab9c346dee777e50f8e601f12c8003fe93a5ecb110d02642d14c3c.json create mode 100644 crates/storage-pg/migrations/20241004121132_queue_job.sql create mode 100644 crates/storage-pg/src/queue/job.rs create mode 100644 crates/storage/src/queue/job.rs diff --git a/crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json b/crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json new file mode 100644 index 000000000..6488ff09c --- /dev/null +++ b/crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json @@ -0,0 +1,37 @@ +{ + "db_name": "PostgreSQL", + "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.payload,\n queue_jobs.metadata\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "queue_job_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "payload", + "type_info": "Jsonb" + }, + { + "ordinal": 2, + "name": "metadata", + "type_info": "Jsonb" + } + ], + "parameters": { + "Left": [ + "TextArray", + "Int8", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5" +} diff --git a/crates/storage-pg/.sqlx/query-e291be0434ab9c346dee777e50f8e601f12c8003fe93a5ecb110d02642d14c3c.json b/crates/storage-pg/.sqlx/query-e291be0434ab9c346dee777e50f8e601f12c8003fe93a5ecb110d02642d14c3c.json new file mode 100644 index 000000000..84ac12de9 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-e291be0434ab9c346dee777e50f8e601f12c8003fe93a5ecb110d02642d14c3c.json @@ -0,0 +1,18 @@ +{ + "db_name": 
"PostgreSQL", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at)\n VALUES ($1, $2, $3, $4, $5)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Jsonb", + "Jsonb", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "e291be0434ab9c346dee777e50f8e601f12c8003fe93a5ecb110d02642d14c3c" +} diff --git a/crates/storage-pg/migrations/20241004121132_queue_job.sql b/crates/storage-pg/migrations/20241004121132_queue_job.sql new file mode 100644 index 000000000..859377d52 --- /dev/null +++ b/crates/storage-pg/migrations/20241004121132_queue_job.sql @@ -0,0 +1,79 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +CREATE TYPE queue_job_status AS ENUM ( + -- The job is available to be picked up by a worker + 'available', + + -- The job is currently being processed by a worker + 'running', + + -- The job has been completed + 'completed', + + -- The worker running the job was lost + 'lost' +); + +CREATE TABLE queue_jobs ( + queue_job_id UUID NOT NULL PRIMARY KEY, + + -- The status of the job + status queue_job_status NOT NULL DEFAULT 'available', + + -- When the job was created + created_at TIMESTAMP WITH TIME ZONE NOT NULL, + + -- When the job was grabbed by a worker + started_at TIMESTAMP WITH TIME ZONE, + + -- Which worker is currently processing the job + started_by UUID REFERENCES queue_workers (queue_worker_id), + + -- When the job was completed + completed_at TIMESTAMP WITH TIME ZONE, + + -- The name of the queue this job belongs to + queue_name TEXT NOT NULL, + + -- The arguments to the job + payload JSONB NOT NULL DEFAULT '{}', + + -- Arbitrary metadata about the job, like the trace context + metadata JSONB NOT NULL DEFAULT '{}' +); + +-- When we grab jobs, we filter on the status of the job and the queue name +-- Then we order on the `queue_job_id` column, as it is a ULID, which ensures timestamp ordering +CREATE INDEX idx_queue_jobs_status_queue_job_id + ON queue_jobs + USING BTREE (status, queue_name, queue_job_id); + +-- We would like to notify workers when a job is available to wake them up +CREATE OR REPLACE FUNCTION queue_job_notify() + RETURNS TRIGGER + AS $$ +DECLARE + payload json; +BEGIN + IF NEW.status = 'available' THEN + -- The idea with this trigger is to notify the queue worker that a new job + -- is available on a queue. If there are many notifications with the same + -- payload, PG will coalesce them in a single notification, which is why we + -- keep the payload simple. + payload = json_build_object('queue', NEW.queue_name); + PERFORM + pg_notify('queue_available', payload::text); + END IF; + RETURN NULL; +END; +$$ +LANGUAGE plpgsql; + +CREATE TRIGGER queue_job_notify_trigger + AFTER INSERT OR UPDATE OF status + ON queue_jobs + FOR EACH ROW + EXECUTE PROCEDURE queue_job_notify(); diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs new file mode 100644 index 000000000..47185243e --- /dev/null +++ b/crates/storage-pg/src/queue/job.rs @@ -0,0 +1,150 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! A module containing the PostgreSQL implementation of the +//! [`QueueJobRepository`]. 
+
+use async_trait::async_trait;
+use mas_storage::{
+    queue::{Job, QueueJobRepository, Worker},
+    Clock,
+};
+use rand::RngCore;
+use sqlx::PgConnection;
+use ulid::Ulid;
+use uuid::Uuid;
+
+use crate::{DatabaseError, ExecuteExt};
+
+/// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
+pub struct PgQueueJobRepository<'c> {
+    conn: &'c mut PgConnection,
+}
+
+impl<'c> PgQueueJobRepository<'c> {
+    /// Create a new [`PgQueueJobRepository`] from an active PostgreSQL
+    /// connection.
+    #[must_use]
+    pub fn new(conn: &'c mut PgConnection) -> Self {
+        Self { conn }
+    }
+}
+
+#[async_trait]
+impl QueueJobRepository for PgQueueJobRepository<'_> {
+    type Error = DatabaseError;
+
+    #[tracing::instrument(
+        name = "db.queue_job.schedule",
+        fields(
+            queue_job.id,
+            queue_job.queue_name = queue_name,
+            db.query.text,
+        ),
+        skip_all,
+        err,
+    )]
+    async fn schedule(
+        &mut self,
+        rng: &mut (dyn RngCore + Send),
+        clock: &dyn Clock,
+        queue_name: &str,
+        payload: serde_json::Value,
+        metadata: serde_json::Value,
+    ) -> Result<(), Self::Error> {
+        let created_at = clock.now();
+        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
+        tracing::Span::current().record("queue_job.id", tracing::field::display(id));
+
+        sqlx::query!(
+            r#"
+            INSERT INTO queue_jobs
+                (queue_job_id, queue_name, payload, metadata, created_at)
+            VALUES ($1, $2, $3, $4, $5)
+            "#,
+            Uuid::from(id),
+            queue_name,
+            payload,
+            metadata,
+            created_at,
+        )
+        .traced()
+        .execute(&mut *self.conn)
+        .await?;
+
+        Ok(())
+    }
+
+    #[tracing::instrument(
+        name = "db.queue_job.get_available",
+        fields(
+            db.query.text,
+        ),
+        skip_all,
+        err,
+    )]
+    async fn get_available(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        max_count: usize,
+    ) -> Result<Vec<Job>, Self::Error> {
+        let now = clock.now();
+        let max_count = i64::try_from(max_count).unwrap_or(i64::MAX);
+        let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
+        sqlx::query!(
+            r#"
+            -- We first grab a few jobs that are available,
+            -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
+            -- and we don't get multiple workers grabbing the same jobs
+            WITH locked_jobs AS (
+                SELECT queue_job_id
+                FROM queue_jobs
+                WHERE
+                    status = 'available'
+                    AND queue_name = ANY($1)
+                ORDER BY queue_job_id ASC
+                LIMIT $2
+                FOR UPDATE
+                SKIP LOCKED
+            )
+            -- then we update the status of those jobs to 'running', returning the job details
+            UPDATE queue_jobs
+            SET status = 'running', started_at = $3, started_by = $4
+            FROM locked_jobs
+            WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
+            RETURNING
+                queue_jobs.queue_job_id,
+                queue_jobs.payload,
+                queue_jobs.metadata
+            "#,
+            &queues,
+            max_count,
+            now,
+            Uuid::from(worker.id),
+        )
+        .traced()
+        .fetch_all(&mut *self.conn)
+        .await?;
+
+        todo!()
+    }
+
+    #[tracing::instrument(
+        name = "db.queue_job.mark_completed",
+        fields(
+            queue_job.id = %job.id,
+            db.query.text,
+        ),
+        skip_all,
+        err,
+    )]
+    async fn mark_completed(&mut self, clock: &dyn Clock, job: Job) -> Result<(), Self::Error> {
+        let _ = clock;
+        let _ = job;
+        todo!()
+    }
+}
diff --git a/crates/storage-pg/src/queue/mod.rs b/crates/storage-pg/src/queue/mod.rs
index b6ba8295e..eca02b809 100644
--- a/crates/storage-pg/src/queue/mod.rs
+++ b/crates/storage-pg/src/queue/mod.rs
@@ -5,4 +5,5 @@
 
 //! A module containing the PostgreSQL implementation of the job queue
 
+pub mod job;
 pub mod worker;
diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs
new file mode 100644
index 000000000..a96780ff5
--- /dev/null
+++ b/crates/storage/src/queue/job.rs
@@ -0,0 +1,105 @@
+// Copyright 2024 New Vector Ltd.
+//
+// SPDX-License-Identifier: AGPL-3.0-only
+// Please see LICENSE in the repository root for full details.
+
+//! Repository to interact with jobs in the job queue
+
+use async_trait::async_trait;
+use rand_core::RngCore;
+use ulid::Ulid;
+
+use super::Worker;
+use crate::{repository_impl, Clock};
+
+enum JobState {
+    /// The job is available to be picked up by a worker
+    Available,
+
+    /// The job is currently being processed by a worker
+    Running,
+
+    /// The job has been completed
+    Completed,
+
+    /// The worker running the job was lost
+    Lost,
+}
+
+/// Represents a job in the job queue
+pub struct Job {
+    /// The ID of the job
+    pub id: Ulid,
+}
+
+/// A [`QueueJobRepository`] is used to schedule jobs to be executed by a
+/// worker.
+#[async_trait]
+pub trait QueueJobRepository: Send + Sync {
+    /// The error type returned by the repository.
+    type Error;
+
+    /// Schedule a job to be executed as soon as possible by a worker.
+    ///
+    /// # Parameters
+    ///
+    /// * `rng` - The random number generator used to generate a new job ID
+    /// * `clock` - The clock used to generate timestamps
+    /// * `queue_name` - The name of the queue to schedule the job on
+    /// * `payload` - The payload of the job
+    /// * `metadata` - Arbitrary metadata about the job
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn schedule(
+        &mut self,
+        rng: &mut (dyn RngCore + Send),
+        clock: &dyn Clock,
+        queue_name: &str,
+        payload: serde_json::Value,
+        metadata: serde_json::Value,
+    ) -> Result<(), Self::Error>;
+
+    /// Get and lock a batch of jobs that are ready to be executed.
+    /// This will transition them to a [`JobState::Running`] state.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn get_available(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        max_count: usize,
+    ) -> Result<Vec<Job>, Self::Error>;
+
+    /// Mark the given job as completed.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn mark_completed(&mut self, clock: &dyn Clock, job: Job) -> Result<(), Self::Error>;
+}
+
+repository_impl!(QueueJobRepository:
+    async fn schedule(
+        &mut self,
+        rng: &mut (dyn RngCore + Send),
+        clock: &dyn Clock,
+        queue_name: &str,
+        payload: serde_json::Value,
+        metadata: serde_json::Value,
+    ) -> Result<(), Self::Error>;
+
+    async fn get_available(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        max_count: usize,
+    ) -> Result<Vec<Job>, Self::Error>;
+
+    async fn mark_completed(&mut self, clock: &dyn Clock, job: Job) -> Result<(), Self::Error>;
+);
diff --git a/crates/storage/src/queue/mod.rs b/crates/storage/src/queue/mod.rs
index 4ca97ec5e..a9757aed1 100644
--- a/crates/storage/src/queue/mod.rs
+++ b/crates/storage/src/queue/mod.rs
@@ -5,6 +5,10 @@
 
 //! 
A module containing repositories for the job queue +mod job; mod worker; -pub use self::worker::{QueueWorkerRepository, Worker}; +pub use self::{ + job::{Job, QueueJobRepository}, + worker::{QueueWorkerRepository, Worker}, +}; From f683bc989fb2721b091721c217ec5acae4af0961 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 15 Oct 2024 13:33:07 +0200 Subject: [PATCH 08/17] Move the jobs types in the queue module --- crates/cli/src/commands/manage.rs | 5 +- crates/handlers/src/admin/v1/users/add.rs | 5 +- .../handlers/src/admin/v1/users/deactivate.rs | 2 +- crates/handlers/src/compat/logout.rs | 3 +- .../src/graphql/mutations/compat_session.rs | 4 +- .../src/graphql/mutations/oauth2_session.rs | 3 +- crates/handlers/src/graphql/mutations/user.rs | 3 +- .../src/graphql/mutations/user_email.rs | 3 +- crates/handlers/src/oauth2/revoke.rs | 3 +- crates/handlers/src/upstream_oauth2/link.rs | 3 +- .../handlers/src/views/account/emails/add.rs | 5 +- .../src/views/account/emails/verify.rs | 5 +- .../handlers/src/views/recovery/progress.rs | 3 +- crates/handlers/src/views/recovery/start.rs | 3 +- crates/handlers/src/views/register.rs | 3 +- ...293c35ba4503855d52a5b62b6e86b126362f5.json | 37 -- crates/storage-pg/src/queue/job.rs | 76 +--- crates/storage/src/job.rs | 293 ---------------- crates/storage/src/queue/job.rs | 145 +++++--- crates/storage/src/queue/mod.rs | 4 +- crates/storage/src/queue/tasks.rs | 330 ++++++++++++++++++ crates/tasks/src/email.rs | 2 +- crates/tasks/src/matrix.rs | 6 +- crates/tasks/src/recovery.rs | 3 +- crates/tasks/src/user.rs | 3 +- 25 files changed, 469 insertions(+), 483 deletions(-) delete mode 100644 crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json create mode 100644 crates/storage/src/queue/tasks.rs diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index c2fe37a10..e0891ecec 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -20,10 +20,9 @@ use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository}, - job::{ - DeactivateUserJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob, - }, + job::JobRepositoryExt, oauth2::OAuth2SessionFilter, + queue::{DeactivateUserJob, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob}, user::{BrowserSessionFilter, UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, }; diff --git a/crates/handlers/src/admin/v1/users/add.rs b/crates/handlers/src/admin/v1/users/add.rs index d45e184bc..33fc299f2 100644 --- a/crates/handlers/src/admin/v1/users/add.rs +++ b/crates/handlers/src/admin/v1/users/add.rs @@ -8,10 +8,7 @@ use aide::{transform::TransformOperation, NoApi, OperationIo}; use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; use mas_matrix::BoxHomeserverConnection; -use mas_storage::{ - job::{JobRepositoryExt, ProvisionUserJob}, - BoxRng, -}; +use mas_storage::{job::JobRepositoryExt, queue::ProvisionUserJob, BoxRng}; use schemars::JsonSchema; use serde::Deserialize; use tracing::warn; diff --git a/crates/handlers/src/admin/v1/users/deactivate.rs b/crates/handlers/src/admin/v1/users/deactivate.rs index 61afcc3a7..091116c8b 100644 --- a/crates/handlers/src/admin/v1/users/deactivate.rs +++ b/crates/handlers/src/admin/v1/users/deactivate.rs @@ -7,7 +7,7 @@ use aide::{transform::TransformOperation, 
OperationIo}; use axum::{response::IntoResponse, Json}; use hyper::StatusCode; -use mas_storage::job::{DeactivateUserJob, JobRepositoryExt}; +use mas_storage::{job::JobRepositoryExt, queue::DeactivateUserJob}; use tracing::info; use ulid::Ulid; diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 59826f0ba..df41a4ad7 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -12,7 +12,8 @@ use mas_axum_utils::sentry::SentryEventID; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - job::{JobRepositoryExt, SyncDevicesJob}, + job::JobRepositoryExt, + queue::SyncDevicesJob, BoxClock, BoxRepository, Clock, RepositoryAccess, }; use thiserror::Error; diff --git a/crates/handlers/src/graphql/mutations/compat_session.rs b/crates/handlers/src/graphql/mutations/compat_session.rs index baf385994..bea5d90d5 100644 --- a/crates/handlers/src/graphql/mutations/compat_session.rs +++ b/crates/handlers/src/graphql/mutations/compat_session.rs @@ -7,9 +7,7 @@ use anyhow::Context as _; use async_graphql::{Context, Enum, InputObject, Object, ID}; use mas_storage::{ - compat::CompatSessionRepository, - job::{JobRepositoryExt, SyncDevicesJob}, - RepositoryAccess, + compat::CompatSessionRepository, job::JobRepositoryExt, queue::SyncDevicesJob, RepositoryAccess, }; use crate::graphql::{ diff --git a/crates/handlers/src/graphql/mutations/oauth2_session.rs b/crates/handlers/src/graphql/mutations/oauth2_session.rs index a0595e695..0dcbd894a 100644 --- a/crates/handlers/src/graphql/mutations/oauth2_session.rs +++ b/crates/handlers/src/graphql/mutations/oauth2_session.rs @@ -9,11 +9,12 @@ use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use chrono::Duration; use mas_data_model::{Device, TokenType}; use mas_storage::{ - job::{JobRepositoryExt, SyncDevicesJob}, + job::JobRepositoryExt, oauth2::{ OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + queue::SyncDevicesJob, user::UserRepository, RepositoryAccess, }; diff --git a/crates/handlers/src/graphql/mutations/user.rs b/crates/handlers/src/graphql/mutations/user.rs index 3fd4a10d5..9cfe22545 100644 --- a/crates/handlers/src/graphql/mutations/user.rs +++ b/crates/handlers/src/graphql/mutations/user.rs @@ -7,7 +7,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use mas_storage::{ - job::{DeactivateUserJob, JobRepositoryExt, ProvisionUserJob}, + job::JobRepositoryExt, + queue::{DeactivateUserJob, ProvisionUserJob}, user::UserRepository, }; use tracing::{info, warn}; diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 91a963dc7..057b18919 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -7,7 +7,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use mas_storage::{ - job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, + job::JobRepositoryExt, + queue::{ProvisionUserJob, VerifyEmailJob}, user::{UserEmailRepository, UserRepository}, RepositoryAccess, }; diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index 584a02053..83abf8147 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -14,8 +14,7 @@ use 
mas_data_model::TokenType; use mas_iana::oauth::OAuthTokenTypeHint; use mas_keystore::Encrypter; use mas_storage::{ - job::{JobRepositoryExt, SyncDevicesJob}, - BoxClock, BoxRepository, RepositoryAccess, + job::JobRepositoryExt, queue::SyncDevicesJob, BoxClock, BoxRepository, RepositoryAccess, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index aea5d32b7..f7d6e2bf9 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -23,7 +23,8 @@ use mas_matrix::BoxHomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, ProvisionUserJob}, + job::JobRepositoryExt, + queue::ProvisionUserJob, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 8ea14474d..dbb5dba1d 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -17,9 +17,8 @@ use mas_data_model::SiteConfig; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, VerifyEmailJob}, - user::UserEmailRepository, - BoxClock, BoxRepository, BoxRng, + job::JobRepositoryExt, queue::VerifyEmailJob, user::UserEmailRepository, BoxClock, + BoxRepository, BoxRng, }; use mas_templates::{EmailAddContext, ErrorContext, TemplateContext, Templates}; use serde::Deserialize; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index a358f6d0a..a25b1028b 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -16,9 +16,8 @@ use mas_axum_utils::{ }; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, ProvisionUserJob}, - user::UserEmailRepository, - BoxClock, BoxRepository, BoxRng, RepositoryAccess, + job::JobRepositoryExt, queue::ProvisionUserJob, user::UserEmailRepository, BoxClock, + BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; diff --git a/crates/handlers/src/views/recovery/progress.rs b/crates/handlers/src/views/recovery/progress.rs index e4d416691..c0a5519df 100644 --- a/crates/handlers/src/views/recovery/progress.rs +++ b/crates/handlers/src/views/recovery/progress.rs @@ -18,8 +18,7 @@ use mas_axum_utils::{ use mas_data_model::SiteConfig; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, SendAccountRecoveryEmailsJob}, - BoxClock, BoxRepository, BoxRng, + job::JobRepositoryExt, queue::SendAccountRecoveryEmailsJob, BoxClock, BoxRepository, BoxRng, }; use mas_templates::{EmptyContext, RecoveryProgressContext, TemplateContext, Templates}; use ulid::Ulid; diff --git a/crates/handlers/src/views/recovery/start.rs b/crates/handlers/src/views/recovery/start.rs index 53896a427..e9cbc758f 100644 --- a/crates/handlers/src/views/recovery/start.rs +++ b/crates/handlers/src/views/recovery/start.rs @@ -21,8 +21,7 @@ use mas_axum_utils::{ use mas_data_model::{SiteConfig, UserAgent}; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, SendAccountRecoveryEmailsJob}, - BoxClock, BoxRepository, 
BoxRng, + job::JobRepositoryExt, queue::SendAccountRecoveryEmailsJob, BoxClock, BoxRepository, BoxRng, }; use mas_templates::{ EmptyContext, FieldError, FormError, FormState, RecoveryStartContext, RecoveryStartFormField, diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index fa8de82f4..d19152331 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -24,7 +24,8 @@ use mas_matrix::BoxHomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, + job::JobRepositoryExt, + queue::{ProvisionUserJob, VerifyEmailJob}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; diff --git a/crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json b/crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json deleted file mode 100644 index 6488ff09c..000000000 --- a/crates/storage-pg/.sqlx/query-0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.payload,\n queue_jobs.metadata\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "queue_job_id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "payload", - "type_info": "Jsonb" - }, - { - "ordinal": 2, - "name": "metadata", - "type_info": "Jsonb" - } - ], - "parameters": { - "Left": [ - "TextArray", - "Int8", - "Timestamptz", - "Uuid" - ] - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "0ac1abe7161c0e58d76d8b1e4de293c35ba4503855d52a5b62b6e86b126362f5" -} diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index 47185243e..e2ef1005a 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -7,10 +7,7 @@ //! [`QueueJobRepository`]. 
use async_trait::async_trait; -use mas_storage::{ - queue::{Job, QueueJobRepository, Worker}, - Clock, -}; +use mas_storage::{queue::QueueJobRepository, Clock}; use rand::RngCore; use sqlx::PgConnection; use ulid::Ulid; @@ -76,75 +73,4 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { Ok(()) } - - #[tracing::instrument( - name = "db.queue_job.get_available", - fields( - db.query.text, - ), - skip_all, - err, - )] - async fn get_available( - &mut self, - clock: &dyn Clock, - worker: &Worker, - queues: &[&str], - max_count: usize, - ) -> Result, Self::Error> { - let now = clock.now(); - let max_count = i64::try_from(max_count).unwrap_or(i64::MAX); - let queues: Vec = queues.iter().map(|&s| s.to_owned()).collect(); - sqlx::query!( - r#" - -- We first grab a few jobs that are available, - -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently - -- and we don't get multiple workers grabbing the same jobs - WITH locked_jobs AS ( - SELECT queue_job_id - FROM queue_jobs - WHERE - status = 'available' - AND queue_name = ANY($1) - ORDER BY queue_job_id ASC - LIMIT $2 - FOR UPDATE - SKIP LOCKED - ) - -- then we update the status of those jobs to 'running', returning the job details - UPDATE queue_jobs - SET status = 'running', started_at = $3, started_by = $4 - FROM locked_jobs - WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id - RETURNING - queue_jobs.queue_job_id, - queue_jobs.payload, - queue_jobs.metadata - "#, - &queues, - max_count, - now, - Uuid::from(worker.id), - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - todo!() - } - - #[tracing::instrument( - name = "db.queue_job.mark_completed", - fields( - queue_job.id = %job.id, - db.query.text, - ), - skip_all, - err, - )] - async fn mark_completed(&mut self, clock: &dyn Clock, job: Job) -> Result<(), Self::Error> { - let _ = clock; - let _ = job; - todo!() - } } diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs index 7b64c8f0b..cb329d6a1 100644 --- a/crates/storage/src/job.rs +++ b/crates/storage/src/job.rs @@ -219,296 +219,3 @@ where .await } } - -mod jobs { - // XXX: Move this somewhere else? - use apalis_core::job::Job; - use mas_data_model::{Device, User, UserEmail, UserRecoverySession}; - use serde::{Deserialize, Serialize}; - use ulid::Ulid; - - /// A job to verify an email address. - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct VerifyEmailJob { - user_email_id: Ulid, - language: Option, - } - - impl VerifyEmailJob { - /// Create a new job to verify an email address. - #[must_use] - pub fn new(user_email: &UserEmail) -> Self { - Self { - user_email_id: user_email.id, - language: None, - } - } - - /// Set the language to use for the email. - #[must_use] - pub fn with_language(mut self, language: String) -> Self { - self.language = Some(language); - self - } - - /// The language to use for the email. - #[must_use] - pub fn language(&self) -> Option<&str> { - self.language.as_deref() - } - - /// The ID of the email address to verify. - #[must_use] - pub fn user_email_id(&self) -> Ulid { - self.user_email_id - } - } - - impl Job for VerifyEmailJob { - const NAME: &'static str = "verify-email"; - } - - /// A job to provision the user on the homeserver. - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct ProvisionUserJob { - user_id: Ulid, - set_display_name: Option, - } - - impl ProvisionUserJob { - /// Create a new job to provision the user on the homeserver. 
- #[must_use] - pub fn new(user: &User) -> Self { - Self { - user_id: user.id, - set_display_name: None, - } - } - - #[doc(hidden)] - #[must_use] - pub fn new_for_id(user_id: Ulid) -> Self { - Self { - user_id, - set_display_name: None, - } - } - - /// Set the display name of the user. - #[must_use] - pub fn set_display_name(mut self, display_name: String) -> Self { - self.set_display_name = Some(display_name); - self - } - - /// Get the display name to be set. - #[must_use] - pub fn display_name_to_set(&self) -> Option<&str> { - self.set_display_name.as_deref() - } - - /// The ID of the user to provision. - #[must_use] - pub fn user_id(&self) -> Ulid { - self.user_id - } - } - - impl Job for ProvisionUserJob { - const NAME: &'static str = "provision-user"; - } - - /// A job to provision a device for a user on the homeserver. - /// - /// This job is deprecated, use the `SyncDevicesJob` instead. It is kept to - /// not break existing jobs in the database. - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct ProvisionDeviceJob { - user_id: Ulid, - device_id: String, - } - - impl ProvisionDeviceJob { - /// The ID of the user to provision the device for. - #[must_use] - pub fn user_id(&self) -> Ulid { - self.user_id - } - - /// The ID of the device to provision. - #[must_use] - pub fn device_id(&self) -> &str { - &self.device_id - } - } - - impl Job for ProvisionDeviceJob { - const NAME: &'static str = "provision-device"; - } - - /// A job to delete a device for a user on the homeserver. - /// - /// This job is deprecated, use the `SyncDevicesJob` instead. It is kept to - /// not break existing jobs in the database. - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct DeleteDeviceJob { - user_id: Ulid, - device_id: String, - } - - impl DeleteDeviceJob { - /// Create a new job to delete a device for a user on the homeserver. - #[must_use] - pub fn new(user: &User, device: &Device) -> Self { - Self { - user_id: user.id, - device_id: device.as_str().to_owned(), - } - } - - /// The ID of the user to delete the device for. - #[must_use] - pub fn user_id(&self) -> Ulid { - self.user_id - } - - /// The ID of the device to delete. 
- #[must_use] - pub fn device_id(&self) -> &str { - &self.device_id - } - } - - impl Job for DeleteDeviceJob { - const NAME: &'static str = "delete-device"; - } - - /// A job which syncs the list of devices of a user with the homeserver - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct SyncDevicesJob { - user_id: Ulid, - } - - impl SyncDevicesJob { - /// Create a new job to sync the list of devices of a user with the - /// homeserver - #[must_use] - pub fn new(user: &User) -> Self { - Self { user_id: user.id } - } - - /// The ID of the user to sync the devices for - #[must_use] - pub fn user_id(&self) -> Ulid { - self.user_id - } - } - - impl Job for SyncDevicesJob { - const NAME: &'static str = "sync-devices"; - } - - /// A job to deactivate and lock a user - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct DeactivateUserJob { - user_id: Ulid, - hs_erase: bool, - } - - impl DeactivateUserJob { - /// Create a new job to deactivate and lock a user - /// - /// # Parameters - /// - /// * `user` - The user to deactivate - /// * `hs_erase` - Whether to erase the user from the homeserver - #[must_use] - pub fn new(user: &User, hs_erase: bool) -> Self { - Self { - user_id: user.id, - hs_erase, - } - } - - /// The ID of the user to deactivate - #[must_use] - pub fn user_id(&self) -> Ulid { - self.user_id - } - - /// Whether to erase the user from the homeserver - #[must_use] - pub fn hs_erase(&self) -> bool { - self.hs_erase - } - } - - impl Job for DeactivateUserJob { - const NAME: &'static str = "deactivate-user"; - } - - /// A job to reactivate a user - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct ReactivateUserJob { - user_id: Ulid, - } - - impl ReactivateUserJob { - /// Create a new job to reactivate a user - /// - /// # Parameters - /// - /// * `user` - The user to reactivate - #[must_use] - pub fn new(user: &User) -> Self { - Self { user_id: user.id } - } - - /// The ID of the user to reactivate - #[must_use] - pub fn user_id(&self) -> Ulid { - self.user_id - } - } - - impl Job for ReactivateUserJob { - const NAME: &'static str = "reactivate-user"; - } - - /// Send account recovery emails - #[derive(Serialize, Deserialize, Debug, Clone)] - pub struct SendAccountRecoveryEmailsJob { - user_recovery_session_id: Ulid, - } - - impl SendAccountRecoveryEmailsJob { - /// Create a new job to send account recovery emails - /// - /// # Parameters - /// - /// * `user_recovery_session` - The user recovery session to send the - /// email for - /// * `language` - The locale to send the email in - #[must_use] - pub fn new(user_recovery_session: &UserRecoverySession) -> Self { - Self { - user_recovery_session_id: user_recovery_session.id, - } - } - - /// The ID of the user recovery session to send the email for - #[must_use] - pub fn user_recovery_session_id(&self) -> Ulid { - self.user_recovery_session_id - } - } - - impl Job for SendAccountRecoveryEmailsJob { - const NAME: &'static str = "send-account-recovery-email"; - } -} - -pub use self::jobs::{ - DeactivateUserJob, DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, ReactivateUserJob, - SendAccountRecoveryEmailsJob, SyncDevicesJob, VerifyEmailJob, -}; diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs index a96780ff5..c5ec3f4f4 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -6,30 +6,67 @@ //! 
Repository to interact with jobs in the job queue
 
 use async_trait::async_trait;
+use opentelemetry::trace::TraceContextExt;
 use rand_core::RngCore;
+use serde::{Deserialize, Serialize};
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use ulid::Ulid;
 
-use super::Worker;
 use crate::{repository_impl, Clock};
 
-enum JobState {
-    /// The job is available to be picked up by a worker
-    Available,
+/// Represents a job in the job queue
+pub struct Job {
+    /// The ID of the job
+    pub id: Ulid,
+
+    /// The payload of the job
+    pub payload: serde_json::Value,
+
+    /// Arbitrary metadata about the job
+    pub metadata: JobMetadata,
+}
 
-    /// The job is currently being processed by a worker
-    Running,
+/// Metadata stored alongside the job
+#[derive(Serialize, Deserialize, Default)]
+pub struct JobMetadata {
+    #[serde(default)]
+    trace_id: String,
 
-    /// The job has been completed
-    Completed,
+    #[serde(default)]
+    span_id: String,
 
-    /// The worker running the job was lost
-    Lost,
+    #[serde(default)]
+    trace_flags: u8,
 }
 
-/// Represents a job in the job queue
-pub struct Job {
-    /// The ID of the job
-    pub id: Ulid,
+impl JobMetadata {
+    fn new(span_context: &opentelemetry::trace::SpanContext) -> Self {
+        Self {
+            trace_id: span_context.trace_id().to_string(),
+            span_id: span_context.span_id().to_string(),
+            trace_flags: span_context.trace_flags().to_u8(),
+        }
+    }
+
+    /// Get the [`opentelemetry::trace::SpanContext`] from this [`JobMetadata`]
+    #[must_use]
+    pub fn span_context(&self) -> opentelemetry::trace::SpanContext {
+        use opentelemetry::trace::{SpanContext, SpanId, TraceFlags, TraceId, TraceState};
+        SpanContext::new(
+            TraceId::from_hex(&self.trace_id).unwrap_or(TraceId::INVALID),
+            SpanId::from_hex(&self.span_id).unwrap_or(SpanId::INVALID),
+            TraceFlags::new(self.trace_flags),
+            // Trace context is remote, as it comes from another service/from the database
+            true,
+            TraceState::NONE,
+        )
+    }
+}
+
+/// A trait that represents a job which can be inserted into a queue
+pub trait InsertableJob: Serialize + Send {
+    /// The name of the queue this job belongs to
+    const QUEUE_NAME: &'static str;
 }
 
 /// A [`QueueJobRepository`] is used to schedule jobs to be executed by a
 /// worker.
 #[async_trait]
 pub trait QueueJobRepository: Send + Sync {
@@ -60,46 +97,72 @@ pub trait QueueJobRepository: Send + Sync {
         payload: serde_json::Value,
         metadata: serde_json::Value,
     ) -> Result<(), Self::Error>;
+}
 
-    /// Get and lock a batch of jobs that are ready to be executed.
-    /// This will transition them to a [`JobState::Running`] state.
-    ///
-    /// # Errors
-    ///
-    /// Returns an error if the underlying repository fails.
-    async fn get_available(
+
+repository_impl!(QueueJobRepository:
+    async fn schedule(
         &mut self,
+        rng: &mut (dyn RngCore + Send),
         clock: &dyn Clock,
-        worker: &Worker,
-        queues: &[&str],
-        max_count: usize,
-    ) -> Result<Vec<Job>, Self::Error>;
+        queue_name: &str,
+        payload: serde_json::Value,
+        metadata: serde_json::Value,
+    ) -> Result<(), Self::Error>;
+);
 
-    /// Mark the given job as completed.
+/// Extension trait for [`QueueJobRepository`] to help adding a job to the queue
+/// through the [`InsertableJob`] trait. This isn't in the
+/// [`QueueJobRepository`] trait to keep it object safe.
+#[async_trait]
+pub trait QueueJobRepositoryExt: QueueJobRepository {
+    /// Schedule a job to be executed as soon as possible by a worker.
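+    ///
+    /// A hypothetical usage sketch, assuming a `queue_job()` accessor on the
+    /// repository (wired up later in this series):
+    ///
+    /// ```ignore
+    /// repo.queue_job()
+    ///     .schedule_job(&mut rng, &clock, VerifyEmailJob::new(&user_email))
+    ///     .await?;
+    /// ```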
+    ///
+    /// # Parameters
+    ///
+    /// * `rng` - The random number generator used to generate a new job ID
+    /// * `clock` - The clock used to generate timestamps
+    /// * `job` - The job to schedule
     ///
     /// # Errors
     ///
     /// Returns an error if the underlying repository fails.
-    async fn mark_completed(&mut self, clock: &dyn Clock, job: Job) -> Result<(), Self::Error>;
-}
-
-repository_impl!(QueueJobRepository:
-    async fn schedule(
+    async fn schedule_job<J: InsertableJob>(
         &mut self,
         rng: &mut (dyn RngCore + Send),
         clock: &dyn Clock,
-        queue_name: &str,
-        payload: serde_json::Value,
-        metadata: serde_json::Value,
+        job: J,
     ) -> Result<(), Self::Error>;
+}
 
-    async fn get_available(
+#[async_trait]
+impl<T> QueueJobRepositoryExt for T
+where
+    T: QueueJobRepository,
+{
+    #[tracing::instrument(
+        name = "db.queue_job.schedule_job",
+        fields(
+            queue_job.queue_name = J::QUEUE_NAME,
+        ),
+        skip_all,
+    )]
+    async fn schedule_job<J: InsertableJob>(
         &mut self,
+        rng: &mut (dyn RngCore + Send),
         clock: &dyn Clock,
-        worker: &Worker,
-        queues: &[&str],
-        max_count: usize,
-    ) -> Result<Vec<Job>, Self::Error>;
+        job: J,
+    ) -> Result<(), Self::Error> {
+        // Grab the span context from the current span
+        let span = tracing::Span::current();
+        let ctx = span.context();
+        let span = ctx.span();
+        let span_context = span.span_context();
 
-    async fn mark_completed(&mut self, clock: &dyn Clock, job: Job) -> Result<(), Self::Error>;
-);
+        let metadata = JobMetadata::new(span_context);
+        let metadata = serde_json::to_value(metadata).expect("Could not serialize metadata");
+
+        let payload = serde_json::to_value(job).expect("Could not serialize job");
+        self.schedule(rng, clock, J::QUEUE_NAME, payload, metadata)
+            .await
+    }
+}
diff --git a/crates/storage/src/queue/mod.rs b/crates/storage/src/queue/mod.rs
index a9757aed1..d02bee5fd 100644
--- a/crates/storage/src/queue/mod.rs
+++ b/crates/storage/src/queue/mod.rs
@@ -6,9 +6,11 @@
 //! A module containing repositories for the job queue
 
 mod job;
+mod tasks;
 mod worker;
 
 pub use self::{
-    job::{Job, QueueJobRepository},
+    job::{InsertableJob, Job, JobMetadata, QueueJobRepository, QueueJobRepositoryExt},
+    tasks::*,
     worker::{QueueWorkerRepository, Worker},
 };
diff --git a/crates/storage/src/queue/tasks.rs b/crates/storage/src/queue/tasks.rs
new file mode 100644
index 000000000..cfc17b471
--- /dev/null
+++ b/crates/storage/src/queue/tasks.rs
@@ -0,0 +1,330 @@
+// Copyright 2024 New Vector Ltd.
+//
+// SPDX-License-Identifier: AGPL-3.0-only
+// Please see LICENSE in the repository root for full details.
+
+use mas_data_model::{Device, User, UserEmail, UserRecoverySession};
+use serde::{Deserialize, Serialize};
+use ulid::Ulid;
+
+use super::InsertableJob;
+
+/// A job to verify an email address.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct VerifyEmailJob {
+    user_email_id: Ulid,
+    language: Option<String>,
+}
+
+impl VerifyEmailJob {
+    /// Create a new job to verify an email address.
+    #[must_use]
+    pub fn new(user_email: &UserEmail) -> Self {
+        Self {
+            user_email_id: user_email.id,
+            language: None,
+        }
+    }
+
+    /// Set the language to use for the email.
+    #[must_use]
+    pub fn with_language(mut self, language: String) -> Self {
+        self.language = Some(language);
+        self
+    }
+
+    /// The language to use for the email.
+    #[must_use]
+    pub fn language(&self) -> Option<&str> {
+        self.language.as_deref()
+    }
+
+    /// The ID of the email address to verify.
+    #[must_use]
+    pub fn user_email_id(&self) -> Ulid {
+        self.user_email_id
+    }
+}
+
+// Implemented for compatibility
+impl apalis_core::job::Job for VerifyEmailJob {
+    const NAME: &'static str = "verify-email";
+}
+
+impl InsertableJob for VerifyEmailJob {
+    const QUEUE_NAME: &'static str = "verify-email";
+}
+
+/// A job to provision the user on the homeserver.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ProvisionUserJob {
+    user_id: Ulid,
+    set_display_name: Option<String>,
+}
+
+impl ProvisionUserJob {
+    /// Create a new job to provision the user on the homeserver.
+    #[must_use]
+    pub fn new(user: &User) -> Self {
+        Self {
+            user_id: user.id,
+            set_display_name: None,
+        }
+    }
+
+    #[doc(hidden)]
+    #[must_use]
+    pub fn new_for_id(user_id: Ulid) -> Self {
+        Self {
+            user_id,
+            set_display_name: None,
+        }
+    }
+
+    /// Set the display name of the user.
+    #[must_use]
+    pub fn set_display_name(mut self, display_name: String) -> Self {
+        self.set_display_name = Some(display_name);
+        self
+    }
+
+    /// Get the display name to be set.
+    #[must_use]
+    pub fn display_name_to_set(&self) -> Option<&str> {
+        self.set_display_name.as_deref()
+    }
+
+    /// The ID of the user to provision.
+    #[must_use]
+    pub fn user_id(&self) -> Ulid {
+        self.user_id
+    }
+}
+
+// Implemented for compatibility
+impl apalis_core::job::Job for ProvisionUserJob {
+    const NAME: &'static str = "provision-user";
+}
+
+impl InsertableJob for ProvisionUserJob {
+    const QUEUE_NAME: &'static str = "provision-user";
+}
+
+/// A job to provision a device for a user on the homeserver.
+///
+/// This job is deprecated, use the `SyncDevicesJob` instead. It is kept to
+/// not break existing jobs in the database.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ProvisionDeviceJob {
+    user_id: Ulid,
+    device_id: String,
+}
+
+impl ProvisionDeviceJob {
+    /// The ID of the user to provision the device for.
+    #[must_use]
+    pub fn user_id(&self) -> Ulid {
+        self.user_id
+    }
+
+    /// The ID of the device to provision.
+    #[must_use]
+    pub fn device_id(&self) -> &str {
+        &self.device_id
+    }
+}
+
+// Implemented for compatibility with older versions
+impl apalis_core::job::Job for ProvisionDeviceJob {
+    const NAME: &'static str = "provision-device";
+}
+
+impl InsertableJob for ProvisionDeviceJob {
+    const QUEUE_NAME: &'static str = "provision-device";
+}
+
+/// A job to delete a device for a user on the homeserver.
+///
+/// This job is deprecated, use the `SyncDevicesJob` instead. It is kept to
+/// not break existing jobs in the database.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct DeleteDeviceJob {
+    user_id: Ulid,
+    device_id: String,
+}
+
+impl DeleteDeviceJob {
+    /// Create a new job to delete a device for a user on the homeserver.
+    #[must_use]
+    pub fn new(user: &User, device: &Device) -> Self {
+        Self {
+            user_id: user.id,
+            device_id: device.as_str().to_owned(),
+        }
+    }
+
+    /// The ID of the user to delete the device for.
+    #[must_use]
+    pub fn user_id(&self) -> Ulid {
+        self.user_id
+    }
+
+    /// The ID of the device to delete.
+ #[must_use] + pub fn device_id(&self) -> &str { + &self.device_id + } +} + +// Implemented for compatibility with older versions +impl apalis_core::job::Job for DeleteDeviceJob { + const NAME: &'static str = "delete-device"; +} + +impl InsertableJob for DeleteDeviceJob { + const QUEUE_NAME: &'static str = "delete-device"; +} + +/// A job which syncs the list of devices of a user with the homeserver +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct SyncDevicesJob { + user_id: Ulid, +} + +impl SyncDevicesJob { + /// Create a new job to sync the list of devices of a user with the + /// homeserver + #[must_use] + pub fn new(user: &User) -> Self { + Self { user_id: user.id } + } + + /// The ID of the user to sync the devices for + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } +} + +// Implemented for compatibility with older versions +impl apalis_core::job::Job for SyncDevicesJob { + const NAME: &'static str = "sync-devices"; +} + +impl InsertableJob for SyncDevicesJob { + const QUEUE_NAME: &'static str = "sync-devices"; +} + +/// A job to deactivate and lock a user +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct DeactivateUserJob { + user_id: Ulid, + hs_erase: bool, +} + +impl DeactivateUserJob { + /// Create a new job to deactivate and lock a user + /// + /// # Parameters + /// + /// * `user` - The user to deactivate + /// * `hs_erase` - Whether to erase the user from the homeserver + #[must_use] + pub fn new(user: &User, hs_erase: bool) -> Self { + Self { + user_id: user.id, + hs_erase, + } + } + + /// The ID of the user to deactivate + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + + /// Whether to erase the user from the homeserver + #[must_use] + pub fn hs_erase(&self) -> bool { + self.hs_erase + } +} + +// Implemented for compatibility with older versions +impl apalis_core::job::Job for DeactivateUserJob { + const NAME: &'static str = "deactivate-user"; +} + +impl InsertableJob for DeactivateUserJob { + const QUEUE_NAME: &'static str = "deactivate-user"; +} + +/// A job to reactivate a user +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ReactivateUserJob { + user_id: Ulid, +} + +impl ReactivateUserJob { + /// Create a new job to reactivate a user + /// + /// # Parameters + /// + /// * `user` - The user to reactivate + #[must_use] + pub fn new(user: &User) -> Self { + Self { user_id: user.id } + } + + /// The ID of the user to reactivate + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } +} + +// Implemented for compatibility with older versions +impl apalis_core::job::Job for ReactivateUserJob { + const NAME: &'static str = "reactivate-user"; +} + +impl InsertableJob for ReactivateUserJob { + const QUEUE_NAME: &'static str = "reactivate-user"; +} + +/// Send account recovery emails +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct SendAccountRecoveryEmailsJob { + user_recovery_session_id: Ulid, +} + +impl SendAccountRecoveryEmailsJob { + /// Create a new job to send account recovery emails + /// + /// # Parameters + /// + /// * `user_recovery_session` - The user recovery session to send the email + /// for + /// * `language` - The locale to send the email in + #[must_use] + pub fn new(user_recovery_session: &UserRecoverySession) -> Self { + Self { + user_recovery_session_id: user_recovery_session.id, + } + } + + /// The ID of the user recovery session to send the email for + #[must_use] + pub fn user_recovery_session_id(&self) -> Ulid { + self.user_recovery_session_id + } +} + +// 
Implemented for compatibility with older versions +impl apalis_core::job::Job for SendAccountRecoveryEmailsJob { + const NAME: &'static str = "send-account-recovery-email"; +} + +impl InsertableJob for SendAccountRecoveryEmailsJob { + const QUEUE_NAME: &'static str = "send-account-recovery-email"; +} diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index 90d0aac35..a16ca29dc 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -9,7 +9,7 @@ use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor use chrono::Duration; use mas_email::{Address, Mailbox}; use mas_i18n::locale; -use mas_storage::job::{JobWithSpanContext, VerifyEmailJob}; +use mas_storage::{job::JobWithSpanContext, queue::VerifyEmailJob}; use mas_templates::{EmailVerificationContext, TemplateContext}; use rand::{distributions::Uniform, Rng}; use tracing::info; diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index ad1cb591e..3cc09b272 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -12,11 +12,9 @@ use mas_data_model::Device; use mas_matrix::ProvisionRequest; use mas_storage::{ compat::CompatSessionFilter, - job::{ - DeleteDeviceJob, JobRepositoryExt as _, JobWithSpanContext, ProvisionDeviceJob, - ProvisionUserJob, SyncDevicesJob, - }, + job::{JobRepositoryExt as _, JobWithSpanContext}, oauth2::OAuth2SessionFilter, + queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob}, user::{UserEmailRepository, UserRepository}, Pagination, RepositoryAccess, }; diff --git a/crates/tasks/src/recovery.rs b/crates/tasks/src/recovery.rs index 142f9f165..79f469b06 100644 --- a/crates/tasks/src/recovery.rs +++ b/crates/tasks/src/recovery.rs @@ -9,7 +9,8 @@ use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor use mas_email::{Address, Mailbox}; use mas_i18n::DataLocale; use mas_storage::{ - job::{JobWithSpanContext, SendAccountRecoveryEmailsJob}, + job::JobWithSpanContext, + queue::SendAccountRecoveryEmailsJob, user::{UserEmailFilter, UserRecoveryRepository}, Pagination, RepositoryAccess, }; diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index 76f81d20e..b3d062bb4 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -8,8 +8,9 @@ use anyhow::Context; use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; use mas_storage::{ compat::CompatSessionFilter, - job::{DeactivateUserJob, JobWithSpanContext, ReactivateUserJob}, + job::JobWithSpanContext, oauth2::OAuth2SessionFilter, + queue::{DeactivateUserJob, ReactivateUserJob}, user::{BrowserSessionFilter, UserRepository}, RepositoryAccess, }; From fdc3d9a16a39c6777778b4998c01a3b1953c8b89 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 15 Oct 2024 14:35:01 +0200 Subject: [PATCH 09/17] Schedule jobs through the new queue --- Cargo.lock | 124 +--------- crates/cli/src/commands/manage.rs | 24 +- crates/cli/src/commands/server.rs | 28 +-- crates/cli/src/commands/worker.rs | 27 +-- crates/handlers/src/admin/v1/users/add.rs | 9 +- .../handlers/src/admin/v1/users/deactivate.rs | 34 +-- crates/handlers/src/compat/logout.rs | 10 +- .../src/graphql/mutations/compat_session.rs | 9 +- .../src/graphql/mutations/oauth2_session.rs | 8 +- crates/handlers/src/graphql/mutations/user.rs | 13 +- .../src/graphql/mutations/user_email.rs | 29 +-- crates/handlers/src/oauth2/revoke.rs | 8 +- crates/handlers/src/upstream_oauth2/link.rs | 5 +- .../handlers/src/views/account/emails/add.rs | 13 +- 
.../src/views/account/emails/verify.rs | 10 +- .../handlers/src/views/recovery/progress.rs | 11 +- crates/handlers/src/views/recovery/start.rs | 11 +- crates/handlers/src/views/register.rs | 15 +- ...11eb91698aa1f2d5d146cffbb7aea8d77467b.json | 16 -- crates/storage-pg/src/job.rs | 67 ------ crates/storage-pg/src/lib.rs | 1 - crates/storage-pg/src/repository.rs | 14 +- crates/storage/Cargo.toml | 8 +- crates/storage/src/job.rs | 221 ------------------ crates/storage/src/lib.rs | 1 - crates/storage/src/queue/tasks.rs | 40 ---- crates/storage/src/repository.rs | 28 ++- crates/tasks/Cargo.toml | 6 - crates/tasks/src/lib.rs | 82 +------ deny.toml | 5 - 30 files changed, 180 insertions(+), 697 deletions(-) delete mode 100644 crates/storage-pg/.sqlx/query-359a00f6667b5b1fef616b0c18e11eb91698aa1f2d5d146cffbb7aea8d77467b.json delete mode 100644 crates/storage-pg/src/job.rs delete mode 100644 crates/storage/src/job.rs diff --git a/Cargo.lock b/Cargo.lock index d8d66a228..71b98659a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -194,44 +194,6 @@ version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" -[[package]] -name = "apalis-core" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1deb48475efcdece1f23a0553209ee842f264c2a5e9bcc4928bfa6a15a044cde" -dependencies = [ - "async-stream", - "async-trait", - "chrono", - "futures", - "graceful-shutdown", - "http", - "log", - "pin-project-lite", - "serde", - "strum 0.25.0", - "thiserror 1.0.69", - "tokio", - "tower 0.4.13", - "tracing", - "ulid", -] - -[[package]] -name = "apalis-cron" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43310b7e0132f9520b09224fb6faafb32eec82a672aa79c09e46b5b488ed505b" -dependencies = [ - "apalis-core", - "async-stream", - "chrono", - "cron", - "futures", - "tokio", - "tower 0.4.13", -] - [[package]] name = "arbitrary" version = "1.3.2" @@ -391,7 +353,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "strum 0.26.3", + "strum", "syn", "thiserror 1.0.69", ] @@ -638,7 +600,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tower 0.5.1", + "tower", "tower-layer", "tower-service", "tracing", @@ -685,7 +647,7 @@ dependencies = [ "multer", "pin-project-lite", "serde", - "tower 0.5.1", + "tower", "tower-layer", "tower-service", ] @@ -1352,17 +1314,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "cron" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f8c3e73077b4b4a6ab1ea5047c37c57aee77657bc8ecd6f29b0af082d0b0c07" -dependencies = [ - "chrono", - "nom", - "once_cell", -] - [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -2138,17 +2089,6 @@ dependencies = [ "spinning_top", ] -[[package]] -name = "graceful-shutdown" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3effbaf774a1da3462925bb182ccf975c284cf46edca5569ea93420a657af484" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - [[package]] name = "group" version = "0.13.0" @@ -3186,7 +3126,7 @@ dependencies = [ "serde_with", "thiserror 2.0.3", "tokio", - "tower 0.5.1", + "tower", "tracing", "ulid", "url", @@ -3255,7 +3195,7 @@ dependencies = [ "sqlx", "tokio", "tokio-util", - "tower 0.5.1", + "tower", "tower-http", "tracing", "tracing-appender", @@ -3395,7 +3335,7 @@ dependencies = [ "time", "tokio", 
"tokio-util", - "tower 0.5.1", + "tower", "tower-http", "tracing", "tracing-subscriber", @@ -3420,7 +3360,7 @@ dependencies = [ "reqwest", "rustls-platform-verifier", "tokio", - "tower 0.5.1", + "tower", "tower-http", "tracing", "tracing-opentelemetry", @@ -3567,7 +3507,7 @@ dependencies = [ "tokio-rustls", "tokio-test", "tokio-util", - "tower 0.5.1", + "tower", "tower-http", "tracing", "tracing-subscriber", @@ -3599,7 +3539,7 @@ dependencies = [ "serde", "serde_json", "thiserror 2.0.3", - "tower 0.5.1", + "tower", "tracing", "url", "urlencoding", @@ -3683,7 +3623,6 @@ dependencies = [ name = "mas-storage" version = "0.12.0" dependencies = [ - "apalis-core", "async-trait", "chrono", "futures-util", @@ -3734,8 +3673,6 @@ name = "mas-tasks" version = "0.12.0" dependencies = [ "anyhow", - "apalis-core", - "apalis-cron", "async-stream", "async-trait", "chrono", @@ -3759,7 +3696,7 @@ dependencies = [ "thiserror 2.0.3", "tokio", "tokio-util", - "tower 0.5.1", + "tower", "tracing", "tracing-opentelemetry", "ulid", @@ -3804,7 +3741,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-semantic-conventions", "pin-project-lite", - "tower 0.5.1", + "tower", "tracing", "tracing-opentelemetry", ] @@ -6117,35 +6054,13 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" -dependencies = [ - "strum_macros 0.25.3", -] - [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ - "strum_macros 0.26.4", -] - -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn", + "strum_macros", ] [[package]] @@ -6484,21 +6399,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "tower" version = "0.5.1" diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index e0891ecec..c7379a960 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -20,9 +20,11 @@ use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository}, - job::JobRepositoryExt, oauth2::OAuth2SessionFilter, - queue::{DeactivateUserJob, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob}, + queue::{ + DeactivateUserJob, ProvisionUserJob, QueueJobRepositoryExt as _, ReactivateUserJob, + SyncDevicesJob, + }, user::{BrowserSessionFilter, UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, }; @@ -365,7 +367,7 @@ impl Options { let id = id.into(); info!(user.id = %id, "Scheduling provisioning job"); let job = 
ProvisionUserJob::new_for_id(id); - repo.job().schedule_job(job).await?; + repo.queue_job().schedule_job(&mut rng, &clock, job).await?; } repo.into_inner().commit().await?; @@ -428,7 +430,9 @@ impl Options { // Schedule a job to sync the devices of the user with the homeserver warn!("Scheduling job to sync devices for the user"); - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; let txn = repo.into_inner(); if dry_run { @@ -466,8 +470,8 @@ impl Options { if deactivate { warn!(%user.id, "Scheduling user deactivation"); - repo.job() - .schedule_job(DeactivateUserJob::new(&user, false)) + repo.queue_job() + .schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, false)) .await?; } @@ -490,8 +494,8 @@ impl Options { .context("User not found")?; warn!(%user.id, "User scheduling user reactivation"); - repo.job() - .schedule_job(ReactivateUserJob::new(&user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ReactivateUserJob::new(&user)) .await?; repo.into_inner().commit().await?; @@ -974,7 +978,9 @@ impl UserCreationRequest<'_> { provision_job = provision_job.set_display_name(display_name); } - repo.job().schedule_job(provision_job).await?; + repo.queue_job() + .schedule_job(rng, clock, provision_job) + .await?; Ok(user) } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 70834ccba..0b0efdc31 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -19,10 +19,6 @@ use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; use mas_storage::SystemClock; use mas_storage_pg::MIGRATOR; -use rand::{ - distributions::{Alphanumeric, DistString}, - thread_rng, -}; use sqlx::migrate::Migrate; use tracing::{info, info_span, warn, Instrument}; @@ -161,13 +157,8 @@ impl Options { let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; - #[allow(clippy::disallowed_methods)] - let mut rng = thread_rng(); - let worker_name = Alphanumeric.sample_string(&mut rng, 10); - - info!(worker_name, "Starting task worker"); - let monitor = mas_tasks::init( - &worker_name, + info!("Starting task worker"); + mas_tasks::init( &pool, &mailer, homeserver_connection.clone(), @@ -176,21 +167,6 @@ impl Options { shutdown.task_tracker(), ) .await?; - - // XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns, - // ideally we'd just give it a cancellation token - let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned(); - shutdown.task_tracker().spawn(async move { - if let Err(e) = monitor - .run_with_signal(async move { - shutdown_future.await; - Ok(()) - }) - .await - { - tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed"); - } - }); } let listeners_config = config.http.listeners.clone(); diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index c58605a1b..3d976d46e 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -11,10 +11,6 @@ use figment::Figment; use mas_config::{AppConfig, ConfigurationSection}; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; -use rand::{ - distributions::{Alphanumeric, DistString}, - thread_rng, -}; use tracing::{info, info_span}; use crate::{ @@ -71,13 +67,8 @@ impl Options { drop(config); - #[allow(clippy::disallowed_methods)] - let mut rng = thread_rng(); - let worker_name = Alphanumeric.sample_string(&mut rng, 10); - 
- info!(worker_name, "Starting task scheduler"); - let monitor = mas_tasks::init( - &worker_name, + info!("Starting task scheduler"); + mas_tasks::init( &pool, &mailer, conn, @@ -87,20 +78,6 @@ impl Options { ) .await?; - // XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns, - // ideally we'd just give it a cancellation token - let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned(); - shutdown.task_tracker().spawn(async move { - if let Err(e) = monitor - .run_with_signal(async move { - shutdown_future.await; - Ok(()) - }) - .await - { - tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed"); - } - }); span.exit(); shutdown.run().await; diff --git a/crates/handlers/src/admin/v1/users/add.rs b/crates/handlers/src/admin/v1/users/add.rs index 33fc299f2..81b17a1ed 100644 --- a/crates/handlers/src/admin/v1/users/add.rs +++ b/crates/handlers/src/admin/v1/users/add.rs @@ -8,7 +8,10 @@ use aide::{transform::TransformOperation, NoApi, OperationIo}; use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; use mas_matrix::BoxHomeserverConnection; -use mas_storage::{job::JobRepositoryExt, queue::ProvisionUserJob, BoxRng}; +use mas_storage::{ + queue::{ProvisionUserJob, QueueJobRepositoryExt as _}, + BoxRng, +}; use schemars::JsonSchema; use serde::Deserialize; use tracing::warn; @@ -161,8 +164,8 @@ pub async fn handler( let user = repo.user().add(&mut rng, &clock, params.username).await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ProvisionUserJob::new(&user)) .await?; repo.save().await?; diff --git a/crates/handlers/src/admin/v1/users/deactivate.rs b/crates/handlers/src/admin/v1/users/deactivate.rs index 091116c8b..ad09a7ca2 100644 --- a/crates/handlers/src/admin/v1/users/deactivate.rs +++ b/crates/handlers/src/admin/v1/users/deactivate.rs @@ -4,10 +4,13 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use aide::{transform::TransformOperation, OperationIo}; +use aide::{transform::TransformOperation, NoApi, OperationIo}; use axum::{response::IntoResponse, Json}; use hyper::StatusCode; -use mas_storage::{job::JobRepositoryExt, queue::DeactivateUserJob}; +use mas_storage::{ + queue::{DeactivateUserJob, QueueJobRepositoryExt as _}, + BoxRng, +}; use tracing::info; use ulid::Ulid; @@ -69,6 +72,7 @@ pub async fn handler( CallContext { mut repo, clock, .. 
     }: CallContext,
+    NoApi(mut rng): NoApi<BoxRng>,
     id: UlidPathParam,
 ) -> Result<Json<SingleResponse<User>>, RouteError> {
     let id = *id;
@@ -83,8 +87,8 @@
     }
 
     info!("Scheduling deactivation of user {}", user.id);
-    repo.job()
-        .schedule_job(DeactivateUserJob::new(&user, true))
+    repo.queue_job()
+        .schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, true))
         .await?;
 
     repo.save().await?;
@@ -133,11 +137,12 @@ mod tests {
 
         // It should have scheduled a deactivation job for the user
         // XXX: we don't have a good way to look for the deactivation job
-        let job: Json<serde_json::Value> =
-            sqlx::query_scalar("SELECT job FROM apalis.jobs WHERE job_type = 'deactivate-user'")
-                .fetch_one(&pool)
-                .await
-                .expect("Deactivation job to be scheduled");
+        let job: Json<serde_json::Value> = sqlx::query_scalar(
+            "SELECT payload FROM queue_jobs WHERE queue_name = 'deactivate-user'",
+        )
+        .fetch_one(&pool)
+        .await
+        .expect("Deactivation job to be scheduled");
 
         assert_eq!(job["user_id"], serde_json::json!(user.id));
     }
@@ -174,11 +179,12 @@ mod tests {
 
         // It should have scheduled a deactivation job for the user
         // XXX: we don't have a good way to look for the deactivation job
-        let job: Json<serde_json::Value> =
-            sqlx::query_scalar("SELECT job FROM apalis.jobs WHERE job_type = 'deactivate-user'")
-                .fetch_one(&pool)
-                .await
-                .expect("Deactivation job to be scheduled");
+        let job: Json<serde_json::Value> = sqlx::query_scalar(
+            "SELECT payload FROM queue_jobs WHERE queue_name = 'deactivate-user'",
+        )
+        .fetch_one(&pool)
+        .await
+        .expect("Deactivation job to be scheduled");
 
         assert_eq!(job["user_id"], serde_json::json!(user.id));
     }
diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs
index df41a4ad7..8ef2dd95b 100644
--- a/crates/handlers/src/compat/logout.rs
+++ b/crates/handlers/src/compat/logout.rs
@@ -12,9 +12,8 @@ use mas_axum_utils::sentry::SentryEventID;
 use mas_data_model::TokenType;
 use mas_storage::{
     compat::{CompatAccessTokenRepository, CompatSessionRepository},
-    job::JobRepositoryExt,
-    queue::SyncDevicesJob,
-    BoxClock, BoxRepository, Clock, RepositoryAccess,
+    queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
+    BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
 };
 use thiserror::Error;
 
@@ -66,6 +65,7 @@ impl IntoResponse for RouteError {
 #[tracing::instrument(name = "handlers.compat.logout.post", skip_all, err)]
 pub(crate) async fn post(
     clock: BoxClock,
+    mut rng: BoxRng,
     mut repo: BoxRepository,
     activity_tracker: BoundActivityTracker,
     maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
@@ -105,7 +105,9 @@ pub(crate) async fn post(
         .ok_or(RouteError::InvalidAuthorization)?;
 
     // Schedule a job to sync the devices of the user with the homeserver
-    repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
+    repo.queue_job()
+        .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
+        .await?;
 
     repo.compat_session().finish(&clock, session).await?;
 
diff --git a/crates/handlers/src/graphql/mutations/compat_session.rs b/crates/handlers/src/graphql/mutations/compat_session.rs
index bea5d90d5..c13c1e559 100644
--- a/crates/handlers/src/graphql/mutations/compat_session.rs
+++ b/crates/handlers/src/graphql/mutations/compat_session.rs
@@ -7,7 +7,9 @@
 use anyhow::Context as _;
 use async_graphql::{Context, Enum, InputObject, Object, ID};
 use mas_storage::{
-    compat::CompatSessionRepository, job::JobRepositoryExt, queue::SyncDevicesJob, RepositoryAccess,
+    compat::CompatSessionRepository,
+    queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
+    RepositoryAccess,
 };
 
 use crate::graphql::{
@@ -70,6 +72,7 @@ impl CompatSessionMutations {
         input: 
EndCompatSessionInput, ) -> Result { let state = ctx.state(); + let mut rng = state.rng(); let compat_session_id = NodeType::CompatSession.extract_ulid(&input.compat_session_id)?; let requester = ctx.requester(); @@ -92,7 +95,9 @@ impl CompatSessionMutations { .context("Could not load user")?; // Schedule a job to sync the devices of the user with the homeserver - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; let session = repo.compat_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/mutations/oauth2_session.rs b/crates/handlers/src/graphql/mutations/oauth2_session.rs index 0dcbd894a..0b1dbe669 100644 --- a/crates/handlers/src/graphql/mutations/oauth2_session.rs +++ b/crates/handlers/src/graphql/mutations/oauth2_session.rs @@ -9,12 +9,11 @@ use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use chrono::Duration; use mas_data_model::{Device, TokenType}; use mas_storage::{ - job::JobRepositoryExt, oauth2::{ OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, - queue::SyncDevicesJob, + queue::{QueueJobRepositoryExt as _, SyncDevicesJob}, user::UserRepository, RepositoryAccess, }; @@ -218,6 +217,7 @@ impl OAuth2SessionMutations { let mut repo = state.repository().await?; let clock = state.clock(); + let mut rng = state.rng(); let session = repo.oauth2_session().lookup(oauth2_session_id).await?; let Some(session) = session else { @@ -236,7 +236,9 @@ impl OAuth2SessionMutations { .context("Could not load user")?; // Schedule a job to sync the devices of the user with the homeserver - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; } let session = repo.oauth2_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/mutations/user.rs b/crates/handlers/src/graphql/mutations/user.rs index 9cfe22545..04d9cc9b3 100644 --- a/crates/handlers/src/graphql/mutations/user.rs +++ b/crates/handlers/src/graphql/mutations/user.rs @@ -7,8 +7,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use mas_storage::{ - job::JobRepositoryExt, - queue::{DeactivateUserJob, ProvisionUserJob}, + queue::{DeactivateUserJob, ProvisionUserJob, QueueJobRepositoryExt as _}, user::UserRepository, }; use tracing::{info, warn}; @@ -399,8 +398,8 @@ impl UserMutations { let user = repo.user().add(&mut rng, &clock, input.username).await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ProvisionUserJob::new(&user)) .await?; repo.save().await?; @@ -415,6 +414,8 @@ impl UserMutations { input: LockUserInput, ) -> Result { let state = ctx.state(); + let clock = state.clock(); + let mut rng = state.rng(); let requester = ctx.requester(); if !requester.is_admin() { @@ -436,8 +437,8 @@ impl UserMutations { if deactivate { info!("Scheduling deactivation of user {}", user.id); - repo.job() - .schedule_job(DeactivateUserJob::new(&user, deactivate)) + repo.queue_job() + .schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, deactivate)) .await?; } diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 057b18919..08604f652 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ 
b/crates/handlers/src/graphql/mutations/user_email.rs @@ -7,8 +7,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use mas_storage::{ - job::JobRepositoryExt, - queue::{ProvisionUserJob, VerifyEmailJob}, + queue::{ProvisionUserJob, QueueJobRepositoryExt as _, VerifyEmailJob}, user::{UserEmailRepository, UserRepository}, RepositoryAccess, }; @@ -377,6 +376,8 @@ impl UserEmailMutations { let state = ctx.state(); let id = NodeType::User.extract_ulid(&input.user_id)?; let requester = ctx.requester(); + let clock = state.clock(); + let mut rng = state.rng(); if !requester.is_owner_or_admin(&UserId(id)) { return Err(async_graphql::Error::new("Unauthorized")); @@ -428,9 +429,6 @@ impl UserEmailMutations { let (added, mut user_email) = if let Some(user_email) = existing_user_email { (false, user_email) } else { - let clock = state.clock(); - let mut rng = state.rng(); - let user_email = repo .user_email() .add(&mut rng, &clock, &user, input.email) @@ -448,8 +446,8 @@ impl UserEmailMutations { .await?; } else { // TODO: figure out the locale - repo.job() - .schedule_job(VerifyEmailJob::new(&user_email)) + repo.queue_job() + .schedule_job(&mut rng, &clock, VerifyEmailJob::new(&user_email)) .await?; } } @@ -471,6 +469,8 @@ impl UserEmailMutations { input: SendVerificationEmailInput, ) -> Result { let state = ctx.state(); + let clock = state.clock(); + let mut rng = state.rng(); let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?; let requester = ctx.requester(); @@ -490,8 +490,8 @@ impl UserEmailMutations { let needs_verification = user_email.confirmed_at.is_none(); if needs_verification { // TODO: figure out the locale - repo.job() - .schedule_job(VerifyEmailJob::new(&user_email)) + repo.queue_job() + .schedule_job(&mut rng, &clock, VerifyEmailJob::new(&user_email)) .await?; } @@ -516,6 +516,7 @@ impl UserEmailMutations { let requester = ctx.requester(); let clock = state.clock(); + let mut rng = state.rng(); let mut repo = state.repository().await?; let user_email = repo @@ -568,8 +569,8 @@ impl UserEmailMutations { .mark_as_verified(&clock, user_email) .await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ProvisionUserJob::new(&user)) .await?; repo.save().await?; @@ -587,6 +588,8 @@ impl UserEmailMutations { let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?; let requester = ctx.requester(); + let mut rng = state.rng(); + let clock = state.clock(); let mut repo = state.repository().await?; let user_email = repo.user_email().lookup(user_email_id).await?; @@ -617,8 +620,8 @@ impl UserEmailMutations { repo.user_email().remove(user_email.clone()).await?; // Schedule a job to update the user - repo.job() - .schedule_job(ProvisionUserJob::new(&user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ProvisionUserJob::new(&user)) .await?; repo.save().await?; diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index 83abf8147..557614c5e 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -14,7 +14,8 @@ use mas_data_model::TokenType; use mas_iana::oauth::OAuthTokenTypeHint; use mas_keystore::Encrypter; use mas_storage::{ - job::JobRepositoryExt, queue::SyncDevicesJob, BoxClock, BoxRepository, RepositoryAccess, + queue::{QueueJobRepositoryExt as _, SyncDevicesJob}, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use oauth2_types::{ 
errors::{ClientError, ClientErrorCode}, @@ -109,6 +110,7 @@ impl From for RouteError { )] pub(crate) async fn post( clock: BoxClock, + mut rng: BoxRng, State(http_client): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, @@ -208,7 +210,9 @@ pub(crate) async fn post( .ok_or(RouteError::UnknownToken)?; // Schedule a job to sync the devices of the user with the homeserver - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; } // Now that we checked everything, we can end the session. diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index f7d6e2bf9..b58da8a3b 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -23,8 +23,7 @@ use mas_matrix::BoxHomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::JobRepositoryExt, - queue::ProvisionUserJob, + queue::{ProvisionUserJob, QueueJobRepositoryExt as _}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, @@ -797,7 +796,7 @@ pub(crate) async fn post( job = job.set_display_name(name); } - repo.job().schedule_job(job).await?; + repo.queue_job().schedule_job(&mut rng, &clock, job).await?; // If we have an email, add it to the user if let Some(email) = email { diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index dbb5dba1d..fbdbccdd5 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -17,8 +17,9 @@ use mas_data_model::SiteConfig; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::JobRepositoryExt, queue::VerifyEmailJob, user::UserEmailRepository, BoxClock, - BoxRepository, BoxRng, + queue::{QueueJobRepositoryExt as _, VerifyEmailJob}, + user::UserEmailRepository, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::{EmailAddContext, ErrorContext, TemplateContext, Templates}; use serde::Deserialize; @@ -136,8 +137,12 @@ pub(crate) async fn post( // If the email was not confirmed, send a confirmation email & redirect to the // verify page let next = if user_email.confirmed_at.is_none() { - repo.job() - .schedule_job(VerifyEmailJob::new(&user_email).with_language(locale.to_string())) + repo.queue_job() + .schedule_job( + &mut rng, + &clock, + VerifyEmailJob::new(&user_email).with_language(locale.to_string()), + ) .await?; let next = mas_router::AccountVerifyEmail::new(user_email.id); diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index a25b1028b..518177c7b 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -16,8 +16,9 @@ use mas_axum_utils::{ }; use mas_router::UrlBuilder; use mas_storage::{ - job::JobRepositoryExt, queue::ProvisionUserJob, user::UserEmailRepository, BoxClock, - BoxRepository, BoxRng, RepositoryAccess, + queue::{ProvisionUserJob, QueueJobRepositoryExt as _}, + user::UserEmailRepository, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; @@ -93,6 +94,7 @@ pub(crate) async fn get( )] pub(crate) async fn post( clock: 
BoxClock, + mut rng: BoxRng, mut repo: BoxRepository, cookie_jar: CookieJar, State(url_builder): State, @@ -140,8 +142,8 @@ pub(crate) async fn post( .mark_as_verified(&clock, user_email) .await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&session.user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ProvisionUserJob::new(&session.user)) .await?; repo.save().await?; diff --git a/crates/handlers/src/views/recovery/progress.rs b/crates/handlers/src/views/recovery/progress.rs index c0a5519df..3c2a178fa 100644 --- a/crates/handlers/src/views/recovery/progress.rs +++ b/crates/handlers/src/views/recovery/progress.rs @@ -18,7 +18,8 @@ use mas_axum_utils::{ use mas_data_model::SiteConfig; use mas_router::UrlBuilder; use mas_storage::{ - job::JobRepositoryExt, queue::SendAccountRecoveryEmailsJob, BoxClock, BoxRepository, BoxRng, + queue::{QueueJobRepositoryExt as _, SendAccountRecoveryEmailsJob}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::{EmptyContext, RecoveryProgressContext, TemplateContext, Templates}; use ulid::Ulid; @@ -135,8 +136,12 @@ pub(crate) async fn post( } // Schedule a new batch of emails - repo.job() - .schedule_job(SendAccountRecoveryEmailsJob::new(&recovery_session)) + repo.queue_job() + .schedule_job( + &mut rng, + &clock, + SendAccountRecoveryEmailsJob::new(&recovery_session), + ) .await?; repo.save().await?; diff --git a/crates/handlers/src/views/recovery/start.rs b/crates/handlers/src/views/recovery/start.rs index e9cbc758f..9b5a7a4d5 100644 --- a/crates/handlers/src/views/recovery/start.rs +++ b/crates/handlers/src/views/recovery/start.rs @@ -21,7 +21,8 @@ use mas_axum_utils::{ use mas_data_model::{SiteConfig, UserAgent}; use mas_router::UrlBuilder; use mas_storage::{ - job::JobRepositoryExt, queue::SendAccountRecoveryEmailsJob, BoxClock, BoxRepository, BoxRng, + queue::{QueueJobRepositoryExt as _, SendAccountRecoveryEmailsJob}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::{ EmptyContext, FieldError, FormError, FormState, RecoveryStartContext, RecoveryStartFormField, @@ -144,8 +145,12 @@ pub(crate) async fn post( ) .await?; - repo.job() - .schedule_job(SendAccountRecoveryEmailsJob::new(&session)) + repo.queue_job() + .schedule_job( + &mut rng, + &clock, + SendAccountRecoveryEmailsJob::new(&session), + ) .await?; repo.save().await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index d19152331..0eaa99801 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -24,8 +24,7 @@ use mas_matrix::BoxHomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::JobRepositoryExt, - queue::{ProvisionUserJob, VerifyEmailJob}, + queue::{ProvisionUserJob, QueueJobRepositoryExt as _, VerifyEmailJob}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; @@ -295,12 +294,16 @@ pub(crate) async fn post( .authenticate_with_password(&mut rng, &clock, &session, &user_password) .await?; - repo.job() - .schedule_job(VerifyEmailJob::new(&user_email).with_language(locale.to_string())) + repo.queue_job() + .schedule_job( + &mut rng, + &clock, + VerifyEmailJob::new(&user_email).with_language(locale.to_string()), + ) .await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&user)) + repo.queue_job() + .schedule_job(&mut rng, &clock, ProvisionUserJob::new(&user)) .await?; repo.save().await?; diff --git 
a/crates/storage-pg/.sqlx/query-359a00f6667b5b1fef616b0c18e11eb91698aa1f2d5d146cffbb7aea8d77467b.json b/crates/storage-pg/.sqlx/query-359a00f6667b5b1fef616b0c18e11eb91698aa1f2d5d146cffbb7aea8d77467b.json deleted file mode 100644 index 941ae4366..000000000 --- a/crates/storage-pg/.sqlx/query-359a00f6667b5b1fef616b0c18e11eb91698aa1f2d5d146cffbb7aea8d77467b.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO apalis.jobs (job, id, job_type)\n VALUES ($1::json, $2::text, $3::text)\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Json", - "Text", - "Text" - ] - }, - "nullable": [] - }, - "hash": "359a00f6667b5b1fef616b0c18e11eb91698aa1f2d5d146cffbb7aea8d77467b" -} diff --git a/crates/storage-pg/src/job.rs b/crates/storage-pg/src/job.rs deleted file mode 100644 index f0630458a..000000000 --- a/crates/storage-pg/src/job.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! A module containing the PostgreSQL implementation of the [`JobRepository`]. - -use async_trait::async_trait; -use mas_storage::job::{JobId, JobRepository, JobSubmission}; -use sqlx::PgConnection; - -use crate::{DatabaseError, ExecuteExt}; - -/// An implementation of [`JobRepository`] for a PostgreSQL connection. -pub struct PgJobRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgJobRepository<'c> { - /// Create a new [`PgJobRepository`] from an active PostgreSQL connection. - #[must_use] - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[async_trait] -impl JobRepository for PgJobRepository<'_> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.job.schedule_submission", - skip_all, - fields( - db.query.text, - job.id, - job.name = submission.name(), - ), - err, - )] - async fn schedule_submission( - &mut self, - submission: JobSubmission, - ) -> Result { - // XXX: This does not use the clock nor the rng - let id = JobId::new(); - tracing::Span::current().record("job.id", tracing::field::display(&id)); - - let res = sqlx::query!( - r#" - INSERT INTO apalis.jobs (job, id, job_type) - VALUES ($1::json, $2::text, $3::text) - "#, - submission.payload(), - id.to_string(), - submission.name(), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(id) - } -} diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index e16303278..3312d73b8 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -164,7 +164,6 @@ use sqlx::migrate::Migrator; pub mod app_session; pub mod compat; -pub mod job; pub mod oauth2; pub mod queue; pub mod upstream_oauth2; diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 99580467c..b5c2b68b2 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -13,7 +13,6 @@ use mas_storage::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, }, - job::JobRepository, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, @@ -34,13 +33,12 @@ use crate::{ PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, 
PgCompatSsoLoginRepository, }, - job::PgJobRepository, oauth2::{ PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, - queue::worker::PgQueueWorkerRepository, + queue::{job::PgQueueJobRepository, worker::PgQueueWorkerRepository}, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -261,13 +259,15 @@ where Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut())) } - fn job<'c>(&'c mut self) -> Box + 'c> { - Box::new(PgJobRepository::new(self.conn.as_mut())) - } - fn queue_worker<'c>( &'c mut self, ) -> Box + 'c> { Box::new(PgQueueWorkerRepository::new(self.conn.as_mut())) } + + fn queue_job<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgQueueJobRepository::new(self.conn.as_mut())) + } } diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 53b62c916..22d209df0 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -14,18 +14,16 @@ workspace = true [dependencies] async-trait.workspace = true chrono.workspace = true -thiserror.workspace = true futures-util.workspace = true - -apalis-core = { version = "0.4.9", features = ["tokio-comp"] } opentelemetry.workspace = true rand_core = "0.6.4" serde.workspace = true serde_json.workspace = true -tracing.workspace = true +thiserror.workspace = true tracing-opentelemetry.workspace = true -url.workspace = true +tracing.workspace = true ulid.workspace = true +url.workspace = true oauth2-types.workspace = true mas-data-model.workspace = true diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs deleted file mode 100644 index cb329d6a1..000000000 --- a/crates/storage/src/job.rs +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Repository to schedule persistent jobs. - -use std::{num::ParseIntError, ops::Deref}; - -pub use apalis_core::job::{Job, JobId}; -use async_trait::async_trait; -use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -use crate::repository_impl; - -/// A job submission to be scheduled through the repository. -pub struct JobSubmission { - name: &'static str, - payload: Value, -} - -#[derive(Serialize, Deserialize)] -struct SerializableSpanContext { - trace_id: String, - span_id: String, - trace_flags: u8, -} - -impl From<&SpanContext> for SerializableSpanContext { - fn from(value: &SpanContext) -> Self { - Self { - trace_id: value.trace_id().to_string(), - span_id: value.span_id().to_string(), - trace_flags: value.trace_flags().to_u8(), - } - } -} - -impl TryFrom<&SerializableSpanContext> for SpanContext { - type Error = ParseIntError; - - fn try_from(value: &SerializableSpanContext) -> Result { - Ok(SpanContext::new( - TraceId::from_hex(&value.trace_id)?, - SpanId::from_hex(&value.span_id)?, - TraceFlags::new(value.trace_flags), - // XXX: is that fine? - true, - TraceState::default(), - )) - } -} - -/// A wrapper for [`Job`] which adds the span context in the payload. 
-#[derive(Serialize, Deserialize)] -pub struct JobWithSpanContext { - #[serde(skip_serializing_if = "Option::is_none")] - span_context: Option, - - #[serde(flatten)] - payload: T, -} - -impl From for JobWithSpanContext { - fn from(payload: J) -> Self { - Self { - span_context: None, - payload, - } - } -} - -impl Job for JobWithSpanContext { - const NAME: &'static str = J::NAME; -} - -impl Deref for JobWithSpanContext { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.payload - } -} - -impl JobWithSpanContext { - /// Get the span context of the job. - /// - /// # Returns - /// - /// Returns [`None`] if the job has no span context, or if the span context - /// is invalid. - #[must_use] - pub fn span_context(&self) -> Option { - self.span_context - .as_ref() - .and_then(|ctx| ctx.try_into().ok()) - } -} - -impl JobSubmission { - /// Create a new job submission out of a [`Job`]. - /// - /// # Panics - /// - /// Panics if the job cannot be serialized. - #[must_use] - pub fn new(job: J) -> Self { - let payload = serde_json::to_value(job).expect("Could not serialize job"); - - Self { - name: J::NAME, - payload, - } - } - - /// Create a new job submission out of a [`Job`] and a [`SpanContext`]. - /// - /// # Panics - /// - /// Panics if the job cannot be serialized. - #[must_use] - pub fn new_with_span_context(job: J, span_context: &SpanContext) -> Self { - // Serialize the span context alongside the job. - let span_context = SerializableSpanContext::from(span_context); - - Self::new(JobWithSpanContext { - payload: job, - span_context: Some(span_context), - }) - } - - /// The name of the job. - #[must_use] - pub fn name(&self) -> &'static str { - self.name - } - - /// The payload of the job. - #[must_use] - pub fn payload(&self) -> &Value { - &self.payload - } -} - -/// A [`JobRepository`] is used to schedule jobs to be executed by a worker. -#[async_trait] -pub trait JobRepository: Send + Sync { - /// The error type returned by the repository. - type Error; - - /// Schedule a job submission to be executed at a later time. - /// - /// # Parameters - /// - /// * `submission` - The job to schedule. - /// - /// # Errors - /// - /// Returns [`Self::Error`] if the underlying repository fails - async fn schedule_submission( - &mut self, - submission: JobSubmission, - ) -> Result; -} - -repository_impl!(JobRepository: - async fn schedule_submission(&mut self, submission: JobSubmission) -> Result; -); - -/// An extension trait for [`JobRepository`] to schedule jobs directly. -#[async_trait] -pub trait JobRepositoryExt { - /// The error type returned by the repository. - type Error; - - /// Schedule a job to be executed at a later time. - /// - /// # Parameters - /// - /// * `job` - The job to schedule. 
- /// - /// # Errors - /// - /// Returns [`Self::Error`] if the underlying repository fails - async fn schedule_job( - &mut self, - job: J, - ) -> Result; -} - -#[async_trait] -impl JobRepositoryExt for T -where - T: JobRepository + ?Sized, -{ - type Error = T::Error; - - #[tracing::instrument( - name = "db.job.schedule_job", - skip_all, - fields( - job.name = J::NAME, - ), - )] - async fn schedule_job( - &mut self, - job: J, - ) -> Result { - let span = tracing::Span::current(); - let ctx = span.context(); - let span = ctx.span(); - let span_context = span.span_context(); - - self.schedule_submission(JobSubmission::new_with_span_context(job, span_context)) - .await - } -} diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 30dc553de..cd0d646c3 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -118,7 +118,6 @@ mod utils; pub mod app_session; pub mod compat; -pub mod job; pub mod oauth2; pub mod queue; pub mod upstream_oauth2; diff --git a/crates/storage/src/queue/tasks.rs b/crates/storage/src/queue/tasks.rs index cfc17b471..a2fe85be4 100644 --- a/crates/storage/src/queue/tasks.rs +++ b/crates/storage/src/queue/tasks.rs @@ -46,11 +46,6 @@ impl VerifyEmailJob { } } -// Implemented for compatibility -impl apalis_core::job::Job for VerifyEmailJob { - const NAME: &'static str = "verify-email"; -} - impl InsertableJob for VerifyEmailJob { const QUEUE_NAME: &'static str = "verify-email"; } @@ -101,11 +96,6 @@ impl ProvisionUserJob { } } -// Implemented for compatibility -impl apalis_core::job::Job for ProvisionUserJob { - const NAME: &'static str = "provision-user"; -} - impl InsertableJob for ProvisionUserJob { const QUEUE_NAME: &'static str = "provision-user"; } @@ -134,11 +124,6 @@ impl ProvisionDeviceJob { } } -// Implemented for compatibility with older versions -impl apalis_core::job::Job for ProvisionDeviceJob { - const NAME: &'static str = "provision-device"; -} - impl InsertableJob for ProvisionDeviceJob { const QUEUE_NAME: &'static str = "provision-device"; } @@ -176,11 +161,6 @@ impl DeleteDeviceJob { } } -// Implemented for compatibility with older versions -impl apalis_core::job::Job for DeleteDeviceJob { - const NAME: &'static str = "delete-device"; -} - impl InsertableJob for DeleteDeviceJob { const QUEUE_NAME: &'static str = "delete-device"; } @@ -206,11 +186,6 @@ impl SyncDevicesJob { } } -// Implemented for compatibility with older versions -impl apalis_core::job::Job for SyncDevicesJob { - const NAME: &'static str = "sync-devices"; -} - impl InsertableJob for SyncDevicesJob { const QUEUE_NAME: &'static str = "sync-devices"; } @@ -250,11 +225,6 @@ impl DeactivateUserJob { } } -// Implemented for compatibility with older versions -impl apalis_core::job::Job for DeactivateUserJob { - const NAME: &'static str = "deactivate-user"; -} - impl InsertableJob for DeactivateUserJob { const QUEUE_NAME: &'static str = "deactivate-user"; } @@ -283,11 +253,6 @@ impl ReactivateUserJob { } } -// Implemented for compatibility with older versions -impl apalis_core::job::Job for ReactivateUserJob { - const NAME: &'static str = "reactivate-user"; -} - impl InsertableJob for ReactivateUserJob { const QUEUE_NAME: &'static str = "reactivate-user"; } @@ -320,11 +285,6 @@ impl SendAccountRecoveryEmailsJob { } } -// Implemented for compatibility with older versions -impl apalis_core::job::Job for SendAccountRecoveryEmailsJob { - const NAME: &'static str = "send-account-recovery-email"; -} - impl InsertableJob for SendAccountRecoveryEmailsJob { const 
QUEUE_NAME: &'static str = "send-account-recovery-email";
 }
diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs
index 55d19d281..161ef05e3 100644
--- a/crates/storage/src/repository.rs
+++ b/crates/storage/src/repository.rs
@@ -13,12 +13,11 @@ use crate::{
         CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
         CompatSsoLoginRepository,
     },
-    job::JobRepository,
    oauth2::{
         OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
         OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
     },
-    queue::QueueWorkerRepository,
+    queue::{QueueJobRepository, QueueWorkerRepository},
     upstream_oauth2::{
         UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
         UpstreamOAuthSessionRepository,
@@ -190,11 +189,11 @@ pub trait RepositoryAccess: Send {
         &'c mut self,
     ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
 
-    /// Get a [`JobRepository`]
-    fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c>;
-
     /// Get a [`QueueWorkerRepository`]
     fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c>;
+
+    /// Get a [`QueueJobRepository`]
+    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c>;
 }
 
 /// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
@@ -209,13 +208,12 @@ mod impls {
             CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
             CompatSsoLoginRepository,
         },
-        job::JobRepository,
         oauth2::{
             OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
             OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository,
             OAuth2SessionRepository,
        },
-        queue::QueueWorkerRepository,
+        queue::{QueueJobRepository, QueueWorkerRepository},
         upstream_oauth2::{
             UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
             UpstreamOAuthSessionRepository,
@@ -407,15 +405,15 @@ mod impls {
             ))
         }
 
-        fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c> {
-            Box::new(MapErr::new(self.inner.job(), &mut self.mapper))
-        }
-
         fn queue_worker<'c>(
             &'c mut self,
         ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
             Box::new(MapErr::new(self.inner.queue_worker(), &mut self.mapper))
         }
+
+        fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
+            Box::new(MapErr::new(self.inner.queue_job(), &mut self.mapper))
+        }
     }
 
     impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
@@ -535,14 +533,14 @@ mod impls {
             (**self).compat_refresh_token()
         }
 
-        fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c> {
-            (**self).job()
-        }
-
         fn queue_worker<'c>(
             &'c mut self,
         ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
             (**self).queue_worker()
         }
+
+        fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
+            (**self).queue_job()
+        }
     }
 }
diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml
index 763b0f5fc..7a18ca0aa 100644
--- a/crates/tasks/Cargo.toml
+++ b/crates/tasks/Cargo.toml
@@ -13,12 +13,6 @@ workspace = true
 
 [dependencies]
 anyhow.workspace = true
-apalis-core = { version = "0.4.9", features = [
-    "extensions",
-    "tokio-comp",
-    "storage",
-] }
-apalis-cron = "0.4.9"
 async-stream = "0.3.6"
 async-trait.workspace = true
 chrono.workspace = true
diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs
index ac85eba14..e56a082c7 100644
--- a/crates/tasks/src/lib.rs
+++ b/crates/tasks/src/lib.rs
@@ -4,9 +4,10 @@
 // SPDX-License-Identifier: AGPL-3.0-only
 // Please see LICENSE in the repository root for full details.
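+// While the apalis-based workers are replaced with the new queue, the task
+// modules below are commented out and some `State` helpers are unused, hence
+// the blanket `dead_code` allow.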
+#![allow(dead_code)] + use std::sync::Arc; -use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor}; use mas_email::Mailer; use mas_matrix::HomeserverConnection; use mas_router::UrlBuilder; @@ -16,18 +17,15 @@ use new_queue::QueueRunnerError; use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; -use tracing::debug; - -use crate::storage::PostgresStorageFactory; -mod database; -mod email; -mod matrix; +// mod database; +// mod email; +// mod matrix; mod new_queue; -mod recovery; -mod storage; -mod user; -mod utils; +// mod recovery; +// mod storage; +// mod user; +// mod utils; #[derive(Clone)] struct State { @@ -55,10 +53,6 @@ impl State { } } - pub fn inject(&self) -> Extension { - Extension(self.clone()) - } - pub fn pool(&self) -> &Pool { &self.pool } @@ -95,58 +89,19 @@ impl State { } } -trait JobContextExt { - fn state(&self) -> State; -} - -impl JobContextExt for apalis_core::context::JobContext { - fn state(&self) -> State { - self.data_opt::() - .expect("state not injected in job context") - .clone() - } -} - -/// Helper macro to build a storage-backed worker. -macro_rules! build { - ($job:ty => $fn:ident, $suffix:expr, $state:expr, $factory:expr) => {{ - let storage = $factory.build(); - let worker_name = format!( - "{job}-{suffix}", - job = <$job as ::apalis_core::job::Job>::NAME, - suffix = $suffix - ); - - let builder = ::apalis_core::builder::WorkerBuilder::new(worker_name) - .layer($state.inject()) - .layer(crate::utils::trace_layer()) - .layer(crate::utils::metrics_layer()); - - let builder = ::apalis_core::storage::builder::WithStorage::with_storage_config( - builder, - storage, - |c| c.fetch_interval(std::time::Duration::from_secs(1)), - ); - ::apalis_core::builder::WorkerFactory::build(builder, ::apalis_core::job_fn::job_fn($fn)) - }}; -} - -pub(crate) use build; - /// Initialise the workers. /// /// # Errors /// /// This function can fail if the database connection fails. 
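+/// It may also fail if the queue worker cannot be started, for instance when
+/// it fails to register itself in the database.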
pub async fn init( - name: &str, pool: &Pool, mailer: &Mailer, homeserver: impl HomeserverConnection + 'static, url_builder: UrlBuilder, cancellation_token: CancellationToken, task_tracker: &TaskTracker, -) -> Result, QueueRunnerError> { +) -> Result<(), QueueRunnerError> { let state = State::new( pool.clone(), SystemClock::default(), @@ -154,21 +109,6 @@ pub async fn init( homeserver, url_builder, ); - let factory = PostgresStorageFactory::new(pool.clone()); - let monitor = Monitor::new().executor(TokioExecutor::new()); - let monitor = self::database::register(name, monitor, &state); - let monitor = self::email::register(name, monitor, &state, &factory); - let monitor = self::matrix::register(name, monitor, &state, &factory); - let monitor = self::user::register(name, monitor, &state, &factory); - let monitor = self::recovery::register(name, monitor, &state, &factory); - // TODO: we might want to grab the join handle here - // TODO: this error isn't right, I just want that to compile - factory - .listen() - .await - .map_err(QueueRunnerError::SetupListener)?; - debug!(?monitor, "workers registered"); - let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?; task_tracker.spawn(async move { @@ -180,5 +120,5 @@ pub async fn init( } }); - Ok(monitor) + Ok(()) } diff --git a/deny.toml b/deny.toml index 571923c35..0adab1ff9 100644 --- a/deny.toml +++ b/deny.toml @@ -69,13 +69,8 @@ skip = [ { name = "heck", version = "0.4.1" }, # wasmtime -> cranelift is depending on this old version { name = "itertools", version = "0.12.1" }, - # apalis-core depends on this old version - { name = "strum", version = "0.25.0" }, - { name = "strum_macros", version = "0.25.0" }, # For some reason, axum-core depends on this old version, even though axum is on the new one { name = "sync_wrapper", version = "0.1.2" }, - # `apalis` depends on this old version of tower - { name = "tower", version = "0.4.13" }, ] skip-tree = [] From 4994761dbf098520d4b93b1ad3fcd72ce242309b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 15 Oct 2024 11:59:32 +0200 Subject: [PATCH 10/17] WIP: job consumption --- crates/tasks/src/new_queue.rs | 44 ++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index 571f8591b..f90b72011 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -3,11 +3,18 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. 
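+//! The worker side of the new job queue.
+//!
+//! Job handlers implement [`RunnableJob`] and are registered on the
+//! [`QueueWorker`], which keeps one factory per queue name to turn a raw
+//! [`Job`] payload back into a typed, boxed job via [`FromJob`].
+//!
+//! As a rough sketch (assuming a job type which already implements both
+//! `RunnableJob` and `InsertableJob`), registering a handler looks like:
+//!
+//! ```ignore
+//! let mut worker = QueueWorker::new(state, cancellation_token).await?;
+//! worker.register_handler::<VerifyEmailJob>();
+//! worker.run().await?;
+//! ```
+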
+use std::collections::HashMap;
+
+use async_trait::async_trait;
 use chrono::{DateTime, Duration, Utc};
-use mas_storage::{queue::Worker, Clock, RepositoryAccess, RepositoryError};
+use mas_storage::{
+    queue::{InsertableJob, Job, Worker},
+    Clock, RepositoryAccess, RepositoryError,
+};
 use mas_storage_pg::{DatabaseError, PgRepository};
 use rand::{distributions::Uniform, Rng};
 use rand_chacha::ChaChaRng;
+use serde::de::DeserializeOwned;
 use sqlx::{
     postgres::{PgAdvisoryLock, PgListener},
     Acquire, Either,
@@ -17,6 +24,30 @@ use tokio_util::sync::CancellationToken;
 
 use crate::State;
 
+pub trait FromJob {
+    fn from_job(job: &Job) -> Result<Self, anyhow::Error>
+    where
+        Self: Sized;
+}
+
+impl<T> FromJob for T
+where
+    T: DeserializeOwned,
+{
+    fn from_job(job: &Job) -> Result<Self, anyhow::Error> {
+        serde_json::from_value(job.payload.clone()).map_err(Into::into)
+    }
+}
+
+#[async_trait]
+pub trait RunnableJob: FromJob + Send + 'static {
+    async fn run(&self, state: &State) -> Result<(), anyhow::Error>;
+}
+
+fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
+    Box::new(job)
+}
+
 #[derive(Debug, Error)]
 pub enum QueueRunnerError {
     #[error("Failed to setup listener")]
@@ -48,6 +79,8 @@ pub enum QueueRunnerError {
 const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
 const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
 
+type JobFactory = Box<dyn Fn(&Job) -> Box<dyn RunnableJob> + Send>;
+
 pub struct QueueWorker {
     rng: ChaChaRng,
     clock: Box<dyn Clock + Send>,
@@ -56,6 +89,7 @@ pub struct QueueWorker {
     am_i_leader: bool,
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
+    factories: HashMap<&'static str, JobFactory>,
 }
 
 impl QueueWorker {
@@ -105,9 +139,17 @@ impl QueueWorker {
             am_i_leader: false,
             last_heartbeat: now,
             cancellation_token,
+            factories: HashMap::new(),
         })
     }
 
+    pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
+        // TODO: error handling
+        let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap());
+        self.factories.insert(T::QUEUE_NAME, Box::new(factory));
+        self
+    }
+
     pub async fn run(&mut self) -> Result<(), QueueRunnerError> {
         while !self.cancellation_token.is_cancelled() {
             self.run_loop().await?;

From 880349c8650dc87bf3c322c94b87c93706a76edf Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Thu, 31 Oct 2024 17:38:43 +0100
Subject: [PATCH 11/17] Actually consume jobs

---
 ...b411aa9f15e7beccfd6212787c3452d35d061.json |  43 ++
 ...ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json |  15 +
 crates/storage-pg/src/queue/job.rs            | 130 +++++-
 crates/storage/src/queue/job.rs               |  48 ++-
 crates/tasks/src/email.rs                     | 162 ++++----
 crates/tasks/src/lib.rs                       |  20 +-
 crates/tasks/src/matrix.rs                    | 390 ++++++++---------
 crates/tasks/src/new_queue.rs                 | 208 +++++++++-
 crates/tasks/src/recovery.rs                  | 198 +++++----
 crates/tasks/src/storage/from_row.rs          |  70 ----
 crates/tasks/src/storage/mod.rs               |  14 -
 crates/tasks/src/storage/postgres.rs          | 391 ------------------
 crates/tasks/src/user.rs                      | 211 +++++-----
 crates/tasks/src/utils.rs                     |  91 ----
 14 files changed, 907 insertions(+), 1084 deletions(-)
 create mode 100644 crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json
 create mode 100644 crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json
 delete mode 100644 crates/tasks/src/storage/from_row.rs
 delete mode 100644 crates/tasks/src/storage/mod.rs
 delete mode 100644 crates/tasks/src/storage/postgres.rs
 delete mode 100644 crates/tasks/src/utils.rs

diff --git 
a/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json b/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json new file mode 100644 index 000000000..67f1ad132 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json @@ -0,0 +1,43 @@ +{ + "db_name": "PostgreSQL", + "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "queue_job_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "queue_name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "payload", + "type_info": "Jsonb" + }, + { + "ordinal": 3, + "name": "metadata", + "type_info": "Jsonb" + } + ], + "parameters": { + "Left": [ + "TextArray", + "Int8", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061" +} diff --git a/crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json b/crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json new file mode 100644 index 000000000..407258ab4 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET status = 'completed', completed_at = $1\n WHERE queue_job_id = $2 AND status = 'running'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27" +} diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index e2ef1005a..02ceed793 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -7,13 +7,16 @@ //! [`QueueJobRepository`]. use async_trait::async_trait; -use mas_storage::{queue::QueueJobRepository, Clock}; +use mas_storage::{ + queue::{Job, QueueJobRepository, Worker}, + Clock, +}; use rand::RngCore; use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{DatabaseError, ExecuteExt}; +use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt}; /// An implementation of [`QueueJobRepository`] for a PostgreSQL connection. 
 pub struct PgQueueJobRepository<'c> {
@@ -29,6 +32,37 @@ impl<'c> PgQueueJobRepository<'c> {
     }
 }
 
+struct JobReservationResult {
+    queue_job_id: Uuid,
+    queue_name: String,
+    payload: serde_json::Value,
+    metadata: serde_json::Value,
+}
+
+impl TryFrom<JobReservationResult> for Job {
+    type Error = DatabaseInconsistencyError;
+
+    fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
+        let id = value.queue_job_id.into();
+        let queue_name = value.queue_name;
+        let payload = value.payload;
+
+        let metadata = serde_json::from_value(value.metadata).map_err(|e| {
+            DatabaseInconsistencyError::on("queue_jobs")
+                .column("metadata")
+                .row(id)
+                .source(e)
+        })?;
+
+        Ok(Self {
+            id,
+            queue_name,
+            payload,
+            metadata,
+        })
+    }
+}
+
 #[async_trait]
 impl QueueJobRepository for PgQueueJobRepository<'_> {
     type Error = DatabaseError;
@@ -73,4 +107,96 @@ impl QueueJobRepository for PgQueueJobRepository<'_> {
 
         Ok(())
     }
+
+    #[tracing::instrument(
+        name = "db.queue_job.reserve",
+        skip_all,
+        fields(
+            db.query.text,
+        ),
+        err,
+    )]
+    async fn reserve(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        count: usize,
+    ) -> Result<Vec<Job>, Self::Error> {
+        let now = clock.now();
+        let max_count = i64::try_from(count).unwrap_or(i64::MAX);
+        let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
+        let results = sqlx::query_as!(
+            JobReservationResult,
+            r#"
+                -- We first grab a few jobs that are available,
+                -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
+                -- and we don't get multiple workers grabbing the same jobs
+                WITH locked_jobs AS (
+                    SELECT queue_job_id
+                    FROM queue_jobs
+                    WHERE
+                        status = 'available'
+                        AND queue_name = ANY($1)
+                    ORDER BY queue_job_id ASC
+                    LIMIT $2
+                    FOR UPDATE
+                    SKIP LOCKED
+                )
+                -- then we update the status of those jobs to 'running', returning the job details
+                UPDATE queue_jobs
+                SET status = 'running', started_at = $3, started_by = $4
+                FROM locked_jobs
+                WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
+                RETURNING
+                    queue_jobs.queue_job_id,
+                    queue_jobs.queue_name,
+                    queue_jobs.payload,
+                    queue_jobs.metadata
+            "#,
+            &queues,
+            max_count,
+            now,
+            Uuid::from(worker.id),
+        )
+        .traced()
+        .fetch_all(&mut *self.conn)
+        .await?;
+
+        let jobs = results
+            .into_iter()
+            .map(TryFrom::try_from)
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Ok(jobs)
+    }
+
+    #[tracing::instrument(
+        name = "db.queue_job.mark_as_completed",
+        skip_all,
+        fields(
+            db.query.text,
+            job.id = %id,
+        ),
+        err,
+    )]
+    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
+        let now = clock.now();
+        let res = sqlx::query!(
+            r#"
+                UPDATE queue_jobs
+                SET status = 'completed', completed_at = $1
+                WHERE queue_job_id = $2 AND status = 'running'
+            "#,
+            now,
+            Uuid::from(id),
+        )
+        .traced()
+        .execute(&mut *self.conn)
+        .await?;
+
+        DatabaseError::ensure_affected_rows(&res, 1)?;
+
+        Ok(())
+    }
 }
diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs
index c5ec3f4f4..13df586d7 100644
--- a/crates/storage/src/queue/job.rs
+++ b/crates/storage/src/queue/job.rs
@@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 use ulid::Ulid;
 
+use super::Worker;
 use crate::{repository_impl, Clock};
 
 /// Represents a job in the job queue
@@ -19,6 +20,9 @@ pub struct Job {
     /// The ID of the job
     pub id: Ulid,
 
+    /// The queue on which the job was placed
+    pub queue_name: String,
+
     /// The payload of the job
     pub payload: serde_json::Value,
 
@@ -27,7 +31,7 @@ pub struct Job {
 }
 
 /// Metadata stored alongside the job
-#[derive(Serialize, Deserialize, Default)]
+#[derive(Serialize, Deserialize, Default, Clone, Debug)]
 pub struct JobMetadata {
     #[serde(default)]
     trace_id: String,
@@ -97,6 +101,38 @@ pub trait QueueJobRepository: Send + Sync {
         payload: serde_json::Value,
         metadata: serde_json::Value,
     ) -> Result<(), Self::Error>;
+
+    /// Reserve multiple jobs from multiple queues
+    ///
+    /// # Parameters
+    ///
+    /// * `clock` - The clock used to generate timestamps
+    /// * `worker` - The worker that is reserving the jobs
+    /// * `queues` - The queues to reserve jobs from
+    /// * `count` - The number of jobs to reserve
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn reserve(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        count: usize,
+    ) -> Result<Vec<Job>, Self::Error>;
+
+    /// Mark a job as completed
+    ///
+    /// # Parameters
+    ///
+    /// * `clock` - The clock used to generate timestamps
+    /// * `id` - The ID of the job to mark as completed
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
 }
 
 repository_impl!(QueueJobRepository:
@@ -108,6 +144,16 @@
     payload: serde_json::Value,
     metadata: serde_json::Value,
     ) -> Result<(), Self::Error>;
+
+    async fn reserve(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        count: usize,
+    ) -> Result<Vec<Job>, Self::Error>;
+
+    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
 );
 
 /// Extension trait for [`QueueJobRepository`] to help adding a job to the queue
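
These two methods are the entire consumption API: `reserve` claims up to `count` jobs across the given queues (the `FOR UPDATE SKIP LOCKED` CTE is what lets several workers run it concurrently without claiming the same rows), and `mark_as_completed` retires one. A minimal sketch of how a caller might drive them; the queue names, function name and surrounding setup are illustrative, not taken from the patch:

use mas_storage::{
    queue::{QueueJobRepository, Worker},
    Clock, RepositoryAccess,
};

// Claim a small batch of jobs and immediately mark them done.
// A real worker would deserialize `job.payload` into a concrete job
// type and run it before completing it.
async fn consume_once<R: RepositoryAccess>(
    repo: &mut R,
    worker: &Worker,
    clock: &dyn Clock,
) -> Result<(), R::Error> {
    let jobs = repo
        .queue_job()
        .reserve(clock, worker, &["verify-email", "provision-user"], 5)
        .await?;

    for job in jobs {
        tracing::info!("running job {} from queue {}", job.id, job.queue_name);
        repo.queue_job().mark_as_completed(clock, job.id).await?;
    }

    Ok(())
}
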
- .context("User not found")?; - - // Generate a verification code - let range = Uniform::::from(0..1_000_000); - let code = rng.sample(range); - let code = format!("{code:06}"); - - let address: Address = user_email.email.parse()?; - - // Save the verification code in the database - let verification = repo - .user_email() - .add_verification_code( - &mut rng, - &clock, - &user_email, - Duration::try_hours(8).unwrap(), - code, - ) - .await?; - - // And send the verification email - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - let context = - EmailVerificationContext::new(user.clone(), verification.clone()).with_language(language); - - mailer.send_verification_email(mailbox, &context).await?; - - info!( - email.id = %user_email.id, - "Verification email sent" - ); - - repo.save().await?; - - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let verify_email_worker = - crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory); - - monitor.register(verify_email_worker) +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; + +#[async_trait] +impl RunnableJob for VerifyEmailJob { + #[tracing::instrument( + name = "job.verify_email", + fields(user_email.id = %self.user_email_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let mailer = state.mailer(); + let clock = state.clock(); + + let language = self + .language() + .and_then(|l| l.parse().ok()) + .unwrap_or(locale!("en").into()); + + // Lookup the user email + let user_email = repo + .user_email() + .lookup(self.user_email_id()) + .await? + .context("User email not found")?; + + // Lookup the user associated with the email + let user = repo + .user() + .lookup(user_email.user_id) + .await? 
+ .context("User not found")?; + + // Generate a verification code + let range = Uniform::::from(0..1_000_000); + let code = rng.sample(range); + let code = format!("{code:06}"); + + let address: Address = user_email.email.parse()?; + + // Save the verification code in the database + let verification = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::try_hours(8).unwrap(), + code, + ) + .await?; + + // And send the verification email + let mailbox = Mailbox::new(Some(user.username.clone()), address); + + let context = EmailVerificationContext::new(user.clone(), verification.clone()) + .with_language(language); + + mailer.send_verification_email(mailbox, &context).await?; + + info!( + email.id = %user_email.id, + "Verification email sent" + ); + + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index e56a082c7..ad2ede868 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -18,14 +18,13 @@ use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; +// TODO: we need to have a way to schedule recurring tasks // mod database; -// mod email; -// mod matrix; +mod email; +mod matrix; mod new_queue; -// mod recovery; -// mod storage; -// mod user; -// mod utils; +mod recovery; +mod user; #[derive(Clone)] struct State { @@ -111,6 +110,15 @@ pub async fn init( ); let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?; + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + task_tracker.spawn(async move { if let Err(e) = worker.run().await { tracing::error!( diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index 3cc09b272..f4596c05f 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -7,239 +7,239 @@ use std::collections::HashSet; use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use mas_data_model::Device; use mas_matrix::ProvisionRequest; use mas_storage::{ compat::CompatSessionFilter, - job::{JobRepositoryExt as _, JobWithSpanContext}, oauth2::OAuth2SessionFilter, - queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob}, + queue::{ + DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _, + SyncDevicesJob, + }, user::{UserEmailRepository, UserRepository}, Pagination, RepositoryAccess, }; use tracing::info; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; /// Job to provision a user on the Matrix homeserver. -/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id} -/// endpoint. -#[tracing::instrument( - name = "job.provision_user" - fields(user.id = %job.user_id()), - skip_all, - err(Debug), -)] -async fn provision_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - let mxid = matrix.mxid(&user.username); - let emails = repo - .user_email() - .all(&user) - .await? 
diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs
index 3cc09b272..f4596c05f 100644
--- a/crates/tasks/src/matrix.rs
+++ b/crates/tasks/src/matrix.rs
@@ -7,239 +7,239 @@ use std::collections::HashSet;
 
 use anyhow::Context;
-use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
+use async_trait::async_trait;
 use mas_data_model::Device;
 use mas_matrix::ProvisionRequest;
 use mas_storage::{
     compat::CompatSessionFilter,
-    job::{JobRepositoryExt as _, JobWithSpanContext},
     oauth2::OAuth2SessionFilter,
-    queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob},
+    queue::{
+        DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
+        SyncDevicesJob,
+    },
     user::{UserEmailRepository, UserRepository},
     Pagination, RepositoryAccess,
 };
 use tracing::info;
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
+use crate::{
+    new_queue::{JobContext, RunnableJob},
+    State,
+};
 
 /// Job to provision a user on the Matrix homeserver.
-/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
-/// endpoint.
-#[tracing::instrument(
-    name = "job.provision_user"
-    fields(user.id = %job.user_id()),
-    skip_all,
-    err(Debug),
-)]
-async fn provision_user(
-    job: JobWithSpanContext<ProvisionUserJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let matrix = state.matrix_connection();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    let mxid = matrix.mxid(&user.username);
-    let emails = repo
-        .user_email()
-        .all(&user)
-        .await?
-        .into_iter()
-        .filter(|email| email.confirmed_at.is_some())
-        .map(|email| email.email)
-        .collect();
-    let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
-
-    if let Some(display_name) = job.display_name_to_set() {
-        request = request.set_displayname(display_name.to_owned());
-    }
+/// This works by doing a PUT request to the
+/// /_synapse/admin/v2/users/{user_id} endpoint.
+#[async_trait]
+impl RunnableJob for ProvisionUserJob {
+    #[tracing::instrument(
+        name = "job.provision_user"
+        fields(user.id = %self.user_id()),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let matrix = state.matrix_connection();
+        let mut repo = state.repository().await?;
+        let mut rng = state.rng();
+        let clock = state.clock();
+
+        let user = repo
+            .user()
+            .lookup(self.user_id())
+            .await?
+            .context("User not found")?;
+
+        let mxid = matrix.mxid(&user.username);
+        let emails = repo
+            .user_email()
+            .all(&user)
+            .await?
+            .into_iter()
+            .filter(|email| email.confirmed_at.is_some())
+            .map(|email| email.email)
+            .collect();
+        let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
+
+        if let Some(display_name) = self.display_name_to_set() {
+            request = request.set_displayname(display_name.to_owned());
+        }
 
-    let created = matrix.provision_user(&request).await?;
+        let created = matrix.provision_user(&request).await?;
 
-    if created {
-        info!(%user.id, %mxid, "User created");
-    } else {
-        info!(%user.id, %mxid, "User updated");
-    }
+        if created {
+            info!(%user.id, %mxid, "User created");
+        } else {
+            info!(%user.id, %mxid, "User updated");
+        }
 
-    // Schedule a device sync job
-    let sync_device_job = SyncDevicesJob::new(&user);
-    repo.job().schedule_job(sync_device_job).await?;
+        // Schedule a device sync job
+        let sync_device_job = SyncDevicesJob::new(&user);
+        repo.queue_job()
+            .schedule_job(&mut rng, &clock, sync_device_job)
+            .await?;
 
-    repo.save().await?;
+        repo.save().await?;
 
-    Ok(())
+        Ok(())
+    }
 }
 
 /// Job to provision a device on the Matrix homeserver.
 ///
 /// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
-#[tracing::instrument(
-    name = "job.provision_device"
-    fields(
-        user.id = %job.user_id(),
-        device.id = %job.device_id(),
-    ),
-    skip_all,
-    err(Debug),
-)]
-async fn provision_device(
-    job: JobWithSpanContext<ProvisionDeviceJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    // Schedule a device sync job
-    repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
-
-    Ok(())
+#[async_trait]
+impl RunnableJob for ProvisionDeviceJob {
+    #[tracing::instrument(
+        name = "job.provision_device"
+        fields(
+            user.id = %self.user_id(),
+            device.id = %self.device_id(),
+        ),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let mut repo = state.repository().await?;
+        let mut rng = state.rng();
+        let clock = state.clock();
+
+        let user = repo
+            .user()
+            .lookup(self.user_id())
+            .await?
+            .context("User not found")?;
+
+        // Schedule a device sync job
+        repo.queue_job()
+            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
+            .await?;
+
+        Ok(())
+    }
 }
 
 /// Job to delete a device from a user's account.
 ///
 /// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
-#[tracing::instrument(
-    name = "job.delete_device"
-    fields(
-        user.id = %job.user_id(),
-        device.id = %job.device_id(),
-    ),
-    skip_all,
-    err(Debug),
-)]
-async fn delete_device(
-    job: JobWithSpanContext<DeleteDeviceJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    // Schedule a device sync job
-    repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
-
-    Ok(())
+#[async_trait]
+impl RunnableJob for DeleteDeviceJob {
+    #[tracing::instrument(
+        name = "job.delete_device"
+        fields(
+            user.id = %self.user_id(),
+            device.id = %self.device_id(),
+        ),
+        skip_all,
+        err(Debug),
+    )]
+    #[tracing::instrument(
+        name = "job.delete_device"
+        fields(
+            user.id = %self.user_id(),
+            device.id = %self.device_id(),
+        ),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let mut rng = state.rng();
+        let clock = state.clock();
+        let mut repo = state.repository().await?;
+
+        let user = repo
+            .user()
+            .lookup(self.user_id())
+            .await?
+            .context("User not found")?;
+
+        // Schedule a device sync job
+        repo.queue_job()
+            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
+            .await?;
+
+        Ok(())
+    }
 }
 
 /// Job to sync the list of devices of a user with the homeserver.
-#[tracing::instrument(
-    name = "job.sync_devices",
-    fields(user.id = %job.user_id()),
-    skip_all,
-    err(Debug),
-)]
-async fn sync_devices(
-    job: JobWithSpanContext<SyncDevicesJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let matrix = state.matrix_connection();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    // Lock the user sync to make sure we don't get into a race condition
-    repo.user().acquire_lock_for_sync(&user).await?;
-
-    let mut devices = HashSet::new();
-
-    // Cycle through all the compat sessions of the user, and grab the devices
-    let mut cursor = Pagination::first(100);
-    loop {
-        let page = repo
-            .compat_session()
-            .list(
-                CompatSessionFilter::new().for_user(&user).active_only(),
-                cursor,
-            )
-            .await?;
-
-        for (compat_session, _) in page.edges {
-            devices.insert(compat_session.device.as_str().to_owned());
-            cursor = cursor.after(compat_session.id);
-        }
+#[async_trait]
+impl RunnableJob for SyncDevicesJob {
+    #[tracing::instrument(
+        name = "job.sync_devices",
+        fields(user.id = %self.user_id()),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let matrix = state.matrix_connection();
+        let mut repo = state.repository().await?;
+
+        let user = repo
+            .user()
+            .lookup(self.user_id())
+            .await?
+            .context("User not found")?;
+
+        // Lock the user sync to make sure we don't get into a race condition
+        repo.user().acquire_lock_for_sync(&user).await?;
+
+        let mut devices = HashSet::new();
+
+        // Cycle through all the compat sessions of the user, and grab the devices
+        let mut cursor = Pagination::first(100);
+        loop {
+            let page = repo
+                .compat_session()
+                .list(
+                    CompatSessionFilter::new().for_user(&user).active_only(),
+                    cursor,
+                )
+                .await?;
+
+            for (compat_session, _) in page.edges {
+                devices.insert(compat_session.device.as_str().to_owned());
+                cursor = cursor.after(compat_session.id);
+            }
 
-        if !page.has_next_page {
-            break;
+            if !page.has_next_page {
+                break;
+            }
         }
-    }
 
-    // Cycle though all the oauth2 sessions of the user, and grab the devices
-    let mut cursor = Pagination::first(100);
-    loop {
-        let page = repo
-            .oauth2_session()
-            .list(
-                OAuth2SessionFilter::new().for_user(&user).active_only(),
-                cursor,
-            )
-            .await?;
-
-        for oauth2_session in page.edges {
-            for scope in &*oauth2_session.scope {
-                if let Some(device) = Device::from_scope_token(scope) {
-                    devices.insert(device.as_str().to_owned());
+        // Cycle through all the oauth2 sessions of the user, and grab the devices
+        let mut cursor = Pagination::first(100);
+        loop {
+            let page = repo
+                .oauth2_session()
+                .list(
+                    OAuth2SessionFilter::new().for_user(&user).active_only(),
+                    cursor,
+                )
+                .await?;
+
+            for oauth2_session in page.edges {
+                for scope in &*oauth2_session.scope {
+                    if let Some(device) = Device::from_scope_token(scope) {
+                        devices.insert(device.as_str().to_owned());
+                    }
                 }
-            }
 
-            cursor = cursor.after(oauth2_session.id);
-        }
+                cursor = cursor.after(oauth2_session.id);
+            }
 
-        if !page.has_next_page {
-            break;
+            if !page.has_next_page {
+                break;
+            }
         }
-    }
 
-    let mxid = matrix.mxid(&user.username);
-    matrix.sync_devices(&mxid, devices).await?;
+        let mxid = matrix.mxid(&user.username);
+        matrix.sync_devices(&mxid, devices).await?;
 
-    // We kept the connection until now, so that we still hold the lock on the user
-    // throughout the sync
-    repo.save().await?;
+        // We kept the connection until now, so that we still hold the lock on the user
+        // throughout the sync
+        repo.save().await?;
 
-    Ok(())
-}
-
-pub(crate) fn register(
-    suffix: &str,
-    monitor: Monitor<TokioExecutor>,
-    state: &State,
-    storage_factory: &PostgresStorageFactory,
-) -> Monitor<TokioExecutor> {
-    let provision_user_worker =
-        crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory);
-    let provision_device_worker =
-        crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory);
-    let delete_device_worker =
-        crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory);
-    let sync_devices_worker =
-        crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory);
-
-    monitor
-        .register(provision_user_worker)
-        .register(provision_device_worker)
-        .register(delete_device_worker)
-        .register(sync_devices_worker)
+        Ok(())
+    }
 }
diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs
index f90b72011..42a037af4 100644
--- a/crates/tasks/src/new_queue.rs
+++ b/crates/tasks/src/new_queue.rs
@@ -3,12 +3,12 @@
 // SPDX-License-Identifier: AGPL-3.0-only
 // Please see LICENSE in the repository root for full details.
 
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use async_trait::async_trait;
 use chrono::{DateTime, Duration, Utc};
 use mas_storage::{
-    queue::{InsertableJob, Job, Worker},
+    queue::{InsertableJob, Job, JobMetadata, Worker},
     Clock, RepositoryAccess, RepositoryError,
 };
 use mas_storage_pg::{DatabaseError, PgRepository};
@@ -20,12 +20,42 @@ use sqlx::{
     Acquire, Either,
 };
 use thiserror::Error;
+use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
+use tracing::{Instrument as _, Span};
+use tracing_opentelemetry::OpenTelemetrySpanExt as _;
+use ulid::Ulid;
 
 use crate::State;
 
+type JobPayload = serde_json::Value;
+
+#[derive(Clone)]
+pub struct JobContext {
+    pub id: Ulid,
+    pub metadata: JobMetadata,
+    pub queue_name: String,
+    pub cancellation_token: CancellationToken,
+}
+
+impl JobContext {
+    pub fn span(&self) -> Span {
+        let span = tracing::info_span!(
+            parent: Span::none(),
+            "job.run",
+            job.id = %self.id,
+            job.queue_name = self.queue_name,
+        );
+
+        span.add_link(self.metadata.span_context());
+
+        span
+    }
+}
+
 pub trait FromJob {
-    fn from_job(job: &Job) -> Result<Self, anyhow::Error>
+    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
     where
         Self: Sized;
 }
@@ -34,14 +64,14 @@ impl<T> FromJob for T
 where
     T: DeserializeOwned,
 {
-    fn from_job(job: &Job) -> Result<Self, anyhow::Error> {
-        serde_json::from_value(job.payload.clone()).map_err(Into::into)
+    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
+        serde_json::from_value(payload).map_err(Into::into)
     }
 }
 
 #[async_trait]
 pub trait RunnableJob: FromJob + Send + 'static {
-    async fn run(&self, state: &State) -> Result<(), anyhow::Error>;
+    async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>;
 }
 
 fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
@@ -79,7 +109,13 @@ pub enum QueueRunnerError {
 const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
 const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
 
-type JobFactory = Box<dyn Fn(&Job) -> Box<dyn RunnableJob> + Send>;
+// How many jobs can we run concurrently
+const MAX_CONCURRENT_JOBS: usize = 10;
+
+// How many jobs can we fetch at once
+const MAX_JOBS_TO_FETCH: usize = 5;
+
+type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
 
 pub struct QueueWorker {
     rng: ChaChaRng,
@@ -89,7 +125,14 @@ pub struct QueueWorker {
     am_i_leader: bool,
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
+    state: State,
+    running_jobs: JoinSet<Result<(), anyhow::Error>>,
+    job_contexts: HashMap<tokio::task::Id, JobContext>,
     factories: HashMap<&'static str, JobFactory>,
+
+    #[allow(clippy::type_complexity)]
+    last_join_result:
+        Option<Result<(tokio::task::Id, Result<(), anyhow::Error>), tokio::task::JoinError>>,
 }
 
 impl QueueWorker {
@@ -115,6 +158,12 @@ impl QueueWorker {
         .await
         .map_err(QueueRunnerError::SetupListener)?;
 
+        // We get notifications when a job is available on this channel
+        listener
+            .listen("queue_available")
+            .await
+            .map_err(QueueRunnerError::SetupListener)?;
+
         let txn = listener
             .begin()
             .await
@@ -139,14 +188,22 @@ impl QueueWorker {
             am_i_leader: false,
             last_heartbeat: now,
             cancellation_token,
+            state,
+            job_contexts: HashMap::new(),
+            running_jobs: JoinSet::new(),
             factories: HashMap::new(),
+            last_join_result: None,
         })
     }
 
     pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
-        // TODO: error handling
-        let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap());
-        self.factories.insert(T::QUEUE_NAME, Box::new(factory));
+        // There is a potential panic here, which is fine as it's going to be caught
+        // within the job task
+        let factory = |payload: JobPayload| {
+            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
+        };
+
+        self.factories.insert(T::QUEUE_NAME, Arc::new(factory));
         self
     }
@@ -164,6 +221,7 @@ impl QueueWorker {
     async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
         self.wait_until_wakeup().await?;
 
+        // TODO: join all the job handles when shutting down
         if self.cancellation_token.is_cancelled() {
             return Ok(());
         }
@@ -214,6 +272,8 @@ impl QueueWorker {
             .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
         let wakeup_sleep = tokio::time::sleep(sleep_duration);
 
+        // TODO: add metrics to track the wake up reasons
+
         tokio::select! {
             () = self.cancellation_token.cancelled() => {
                 tracing::debug!("Woke up from cancellation");
@@ -223,6 +283,11 @@ impl QueueWorker {
                 tracing::debug!("Woke up from sleep");
             },
 
+            Some(result) = self.running_jobs.join_next_with_id() => {
+                tracing::debug!("Joined job task");
+                self.last_join_result = Some(result);
+            },
+
             notification = self.listener.recv() => {
                 match notification {
                     Ok(notification) => {
@@ -281,6 +346,127 @@ impl QueueWorker {
             .try_get_leader_lease(&self.clock, &self.registration)
             .await?;
 
+        // Find any job task which finished
+        // If we got woken up by a join on the joinset, it will be stored in the
+        // last_join_result so that we don't lose it
+
+        if self.last_join_result.is_none() {
+            self.last_join_result = self.running_jobs.try_join_next_with_id();
+        }
+
+        while let Some(result) = self.last_join_result.take() {
+            // TODO: add metrics to track the job status and the time it took
+            let context = match result {
+                Ok((id, Ok(()))) => {
+                    // The job succeeded
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::info!(
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        "Job completed"
+                    );
+
+                    context
+                }
+                Ok((id, Err(e))) => {
+                    // The job failed
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::error!(
+                        error = ?e,
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        "Job failed"
+                    );
+
+                    // TODO: reschedule the job
+
+                    context
+                }
+                Err(e) => {
+                    // The job crashed (or was cancelled)
+                    let id = e.id();
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::error!(
+                        error = &e as &dyn std::error::Error,
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        "Job crashed"
+                    );
+
+                    // TODO: reschedule the job
+
+                    context
+                }
+            };
+
+            repo.queue_job()
+                .mark_as_completed(&self.clock, context.id)
+                .await?;
+
+            self.last_join_result = self.running_jobs.try_join_next_with_id();
+        }
+
+        // Compute how many jobs we should fetch at most
+        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
+            .saturating_sub(self.running_jobs.len())
+            .min(MAX_JOBS_TO_FETCH);
+
+        if max_jobs_to_fetch == 0 {
+            tracing::warn!("Internal job queue is full, not fetching any new jobs");
+        } else {
+            // Grab a few jobs in the queue
+            let queues = self.factories.keys().copied().collect::<Vec<_>>();
+            let jobs = repo
+                .queue_job()
+                .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
+                .await?;
+
+            for Job {
+                id,
+                queue_name,
+                payload,
+                metadata,
+            } in jobs
+            {
+                let cancellation_token = self.cancellation_token.child_token();
+                let factory = self.factories.get(queue_name.as_str()).cloned();
+                let context = JobContext {
+                    id,
+                    metadata,
+                    queue_name,
+                    cancellation_token,
+                };
+
+                let task = {
+                    let context = context.clone();
+                    let span = context.span();
+                    let state = self.state.clone();
+                    async move {
+                        // We should never crash, but in case we do, we do that in the task and
+                        // don't crash the worker
+                        let job = factory.expect("unknown job factory")(payload);
+                        job.run(&state, context).await
+                    }
+                    .instrument(span)
+                };
+
+                let handle = self.running_jobs.spawn(task);
+                self.job_contexts.insert(handle.id(), context);
+            }
+        }
+
         // After this point, we are locking the leader table, so it's important that we
         // commit as soon as possible to not block the other workers for too long
         repo.into_inner()
@@ -353,6 +539,8 @@ impl QueueWorker {
             .shutdown_dead_workers(&self.clock, Duration::minutes(2))
             .await?;
 
+        // TODO: mark the tasks those workers had as lost
+
         // Release the leader lock
         let txn = repo
             .into_inner()
diff --git a/crates/tasks/src/recovery.rs b/crates/tasks/src/recovery.rs
index 79f469b06..cd3787d2a 100644
--- a/crates/tasks/src/recovery.rs
+++ b/crates/tasks/src/recovery.rs
@@ -5,11 +5,10 @@
 // Please see LICENSE in the repository root for full details.
 
 use anyhow::Context;
-use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
+use async_trait::async_trait;
 use mas_email::{Address, Mailbox};
 use mas_i18n::DataLocale;
 use mas_storage::{
-    job::JobWithSpanContext,
     queue::SendAccountRecoveryEmailsJob,
     user::{UserEmailFilter, UserRecoveryRepository},
     Pagination, RepositoryAccess,
@@ -18,117 +17,108 @@ use mas_templates::{EmailRecoveryContext, TemplateContext};
 use rand::distributions::{Alphanumeric, DistString};
 use tracing::{error, info};
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
+use crate::{
+    new_queue::{JobContext, RunnableJob},
+    State,
+};
 
 /// Job to send account recovery emails for a given recovery session.
-#[tracing::instrument(
-    name = "job.send_account_recovery_email",
-    fields(
-        user_recovery_session.id = %job.user_recovery_session_id(),
-        user_recovery_session.email,
-    ),
-    skip_all,
-    err(Debug),
-)]
-async fn send_account_recovery_email_job(
-    job: JobWithSpanContext<SendAccountRecoveryEmailsJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let clock = state.clock();
-    let mailer = state.mailer();
-    let url_builder = state.url_builder();
-    let mut rng = state.rng();
-    let mut repo = state.repository().await?;
-
-    let session = repo
-        .user_recovery()
-        .lookup_session(job.user_recovery_session_id())
-        .await?
-        .context("User recovery session not found")?;
-
-    tracing::Span::current().record("user_recovery_session.email", &session.email);
-
-    if session.consumed_at.is_some() {
-        info!("Recovery session already consumed, not sending email");
-        return Ok(());
-    }
+#[async_trait]
+impl RunnableJob for SendAccountRecoveryEmailsJob {
+    #[tracing::instrument(
+        name = "job.send_account_recovery_email",
+        fields(
+            user_recovery_session.id = %self.user_recovery_session_id(),
+            user_recovery_session.email,
+        ),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let clock = state.clock();
+        let mailer = state.mailer();
+        let url_builder = state.url_builder();
+        let mut rng = state.rng();
+        let mut repo = state.repository().await?;
+
+        let session = repo
+            .user_recovery()
+            .lookup_session(self.user_recovery_session_id())
+            .await?
+ .context("User recovery session not found")?; + + tracing::Span::current().record("user_recovery_session.email", &session.email); + + if session.consumed_at.is_some() { + info!("Recovery session already consumed, not sending email"); + return Ok(()); + } - let mut cursor = Pagination::first(50); - - let lang: DataLocale = session - .locale - .parse() - .context("Invalid locale in database on recovery session")?; - - loop { - let page = repo - .user_email() - .list( - UserEmailFilter::new() - .for_email(&session.email) - .verified_only(), - cursor, - ) - .await?; - - for email in page.edges { - let ticket = Alphanumeric.sample_string(&mut rng, 32); - - let ticket = repo - .user_recovery() - .add_ticket(&mut rng, &clock, &session, &email, ticket) - .await?; + let mut cursor = Pagination::first(50); + + let lang: DataLocale = session + .locale + .parse() + .context("Invalid locale in database on recovery session")?; - let user_email = repo + loop { + let page = repo .user_email() - .lookup(email.id) - .await? - .context("User email not found")?; - - let user = repo - .user() - .lookup(user_email.user_id) - .await? - .context("User not found")?; - - let url = url_builder.account_recovery_link(ticket.ticket); - - let address: Address = user_email.email.parse()?; - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - info!("Sending recovery email to {}", mailbox); - let context = - EmailRecoveryContext::new(user, session.clone(), url).with_language(lang.clone()); - - // XXX: we only log if the email fails to send, to avoid stopping the loop - if let Err(e) = mailer.send_recovery_email(mailbox, &context).await { - error!( - error = &e as &dyn std::error::Error, - "Failed to send recovery email" - ); - } + .list( + UserEmailFilter::new() + .for_email(&session.email) + .verified_only(), + cursor, + ) + .await?; - cursor = cursor.after(email.id); - } + for email in page.edges { + let ticket = Alphanumeric.sample_string(&mut rng, 32); - if !page.has_next_page { - break; - } - } + let ticket = repo + .user_recovery() + .add_ticket(&mut rng, &clock, &session, &email, ticket) + .await?; - repo.save().await?; + let user_email = repo + .user_email() + .lookup(email.id) + .await? + .context("User email not found")?; - Ok(()) -} + let user = repo + .user() + .lookup(user_email.user_id) + .await? 
+ .context("User not found")?; -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let send_user_recovery_email_worker = crate::build!(SendAccountRecoveryEmailsJob => send_account_recovery_email_job, suffix, state, storage_factory); + let url = url_builder.account_recovery_link(ticket.ticket); - monitor.register(send_user_recovery_email_worker) + let address: Address = user_email.email.parse()?; + let mailbox = Mailbox::new(Some(user.username.clone()), address); + + info!("Sending recovery email to {}", mailbox); + let context = EmailRecoveryContext::new(user, session.clone(), url) + .with_language(lang.clone()); + + // XXX: we only log if the email fails to send, to avoid stopping the loop + if let Err(e) = mailer.send_recovery_email(mailbox, &context).await { + error!( + error = &e as &dyn std::error::Error, + "Failed to send recovery email" + ); + } + + cursor = cursor.after(email.id); + } + + if !page.has_next_page { + break; + } + } + + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/storage/from_row.rs b/crates/tasks/src/storage/from_row.rs deleted file mode 100644 index 5acf6848a..000000000 --- a/crates/tasks/src/storage/from_row.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::str::FromStr; - -use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId}; -use chrono::{DateTime, Utc}; -use serde_json::Value; -use sqlx::Row; - -/// Wrapper for [`JobRequest`] -pub(crate) struct SqlJobRequest(JobRequest); - -impl From> for JobRequest { - fn from(val: SqlJobRequest) -> Self { - val.0 - } -} - -impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow> - for SqlJobRequest -{ - fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { - let job: Value = row.try_get("job")?; - let id: JobId = - JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode { - index: "id".to_owned(), - source: Box::new(e), - })?; - let mut context = JobContext::new(id); - - let run_at = row.try_get("run_at")?; - context.set_run_at(run_at); - - let attempts = row.try_get("attempts").unwrap_or(0); - context.set_attempts(attempts); - - let max_attempts = row.try_get("max_attempts").unwrap_or(25); - context.set_max_attempts(max_attempts); - - let done_at: Option> = row.try_get("done_at").unwrap_or_default(); - context.set_done_at(done_at); - - let lock_at: Option> = row.try_get("lock_at").unwrap_or_default(); - context.set_lock_at(lock_at); - - let last_error = row.try_get("last_error").unwrap_or_default(); - context.set_last_error(last_error); - - let status: String = row.try_get("status")?; - context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode { - index: "job".to_owned(), - source: Box::new(e), - })?); - - let lock_by: Option = row.try_get("lock_by").unwrap_or_default(); - context.set_lock_by(lock_by.map(WorkerId::new)); - - Ok(SqlJobRequest(JobRequest::new_with_context( - serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode { - index: "job".to_owned(), - source: Box::new(e), - })?, - context, - ))) - } -} diff --git a/crates/tasks/src/storage/mod.rs b/crates/tasks/src/storage/mod.rs deleted file mode 100644 index 5f6e77e31..000000000 --- a/crates/tasks/src/storage/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -// 
Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a -//! shared connection for the [`PgListener`] - -mod from_row; -mod postgres; - -use self::from_row::SqlJobRequest; -pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory; diff --git a/crates/tasks/src/storage/postgres.rs b/crates/tasks/src/storage/postgres.rs deleted file mode 100644 index f709579ed..000000000 --- a/crates/tasks/src/storage/postgres.rs +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration}; - -use apalis_core::{ - error::JobStreamError, - job::{Job, JobId, JobStreamResult}, - request::JobRequest, - storage::{StorageError, StorageResult, StorageWorkerPulse}, - utils::Timer, - worker::WorkerId, -}; -use async_stream::try_stream; -use chrono::{DateTime, Utc}; -use event_listener::Event; -use futures_lite::{Stream, StreamExt}; -use serde::{de::DeserializeOwned, Serialize}; -use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row}; -use tokio::task::JoinHandle; - -use super::SqlJobRequest; - -pub struct StorageFactory { - pool: PgPool, - event: Arc, -} - -impl StorageFactory { - pub fn new(pool: Pool) -> Self { - StorageFactory { - pool, - event: Arc::new(Event::new()), - } - } - - pub async fn listen(self) -> Result, sqlx::Error> { - let mut listener = PgListener::connect_with(&self.pool).await?; - listener.listen("apalis::job").await?; - - let handle = tokio::spawn(async move { - loop { - let notification = listener.recv().await.expect("Failed to poll notification"); - self.event.notify(usize::MAX); - tracing::debug!(?notification, "Broadcast notification"); - } - }); - - Ok(handle) - } - - pub fn build(&self) -> Storage { - Storage { - pool: self.pool.clone(), - event: self.event.clone(), - job_type: PhantomData, - } - } -} - -/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres -#[derive(Debug)] -pub struct Storage { - pool: PgPool, - event: Arc, - job_type: PhantomData, -} - -impl Clone for Storage { - fn clone(&self) -> Self { - Storage { - pool: self.pool.clone(), - event: self.event.clone(), - job_type: PhantomData, - } - } -} - -impl Storage { - fn stream_jobs( - &self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> impl Stream, JobStreamError>> { - let pool = self.pool.clone(); - let sleeper = apalis_core::utils::timer::TokioTimer; - let worker_id = worker_id.clone(); - let event = self.event.clone(); - try_stream! { - loop { - // Wait for a notification or a timeout - let listener = event.listen(); - let interval = sleeper.sleep(interval); - futures_lite::future::race(interval, listener).await; - - let tx = pool.clone(); - let job_type = T::NAME; - let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);"; - let jobs: Vec> = sqlx::query_as(fetch_query) - .bind(worker_id.name()) - .bind(job_type) - // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html - .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?) 
- .fetch_all(&tx) - .await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?; - for job in jobs { - yield job.into() - } - } - } - } - - async fn keep_alive_at( - &mut self, - worker_id: &WorkerId, - last_seen: DateTime, - ) -> StorageResult<()> { - let pool = self.pool.clone(); - - let worker_type = T::NAME; - let storage_name = std::any::type_name::(); - let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (id) DO - UPDATE SET last_seen = EXCLUDED.last_seen"; - sqlx::query(query) - .bind(worker_id.name()) - .bind(worker_type) - .bind(storage_name) - .bind(std::any::type_name::()) - .bind(last_seen) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } -} - -#[async_trait::async_trait] -impl apalis_core::storage::Storage for Storage -where - T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, -{ - type Output = T; - - /// Push a job to Postgres [Storage] - /// - /// # SQL Example - /// - /// ```sql - /// SELECT apalis.push_job(job_type::text, job::json); - /// ``` - async fn push(&mut self, job: Self::Output) -> StorageResult { - let id = JobId::new(); - let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); - let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?; - let job_type = T::NAME; - sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(id) - } - - async fn schedule( - &mut self, - job: Self::Output, - on: chrono::DateTime, - ) -> StorageResult { - let query = - "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)"; - - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - - let id = JobId::new(); - let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?; - let job_type = T::NAME; - sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type) - .bind(on) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(id) - } - - async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult>> { - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - - let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1"; - let res: Option> = sqlx::query_as(fetch_query) - .bind(job_id.to_string()) - .fetch_optional(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(res.map(Into::into)) - } - - async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult { - match pulse { - StorageWorkerPulse::EnqueueScheduled { count: _ } => { - // Ideally jobs are queue via run_at. So this is not necessary - Ok(true) - } - - // Worker not seen in 5 minutes yet has running jobs - StorageWorkerPulse::ReenqueueOrphaned { count, .. 
} => { - let job_type = T::NAME; - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - let query = "UPDATE apalis.jobs - SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned' - WHERE id in - (SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id - WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes' - AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);"; - sqlx::query(query) - .bind(job_type) - .bind(count) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(true) - } - - _ => unimplemented!(), - } - } - - async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - /// Puts the job instantly back into the queue - /// Another [Worker] may consume - async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - fn consume( - &mut self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> JobStreamResult { - Box::pin( - self.stream_jobs(worker_id, interval, buffer_size) - .map(|r| r.map(Some)), - ) - } - async fn len(&self) -> StorageResult { - let pool = self.pool.clone(); - let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'"; - let record = sqlx::query(query) - .fetch_one(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(record - .try_get("count") - .map_err(|e| StorageError::Database(Box::from(e)))?) - } - async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - let query = - "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn reschedule(&mut self, job: &JobRequest, wait: Duration) -> StorageResult<()> { - let pool = self.pool.clone(); - let job_id = job.id(); - - let wait: i64 = wait - .as_secs() - .try_into() - .map_err(|e| StorageError::Database(Box::new(e)))?; - let wait = chrono::Duration::microseconds(wait * 1000 * 1000); - // TODO: should we use a clock here? 
- #[allow(clippy::disallowed_methods)] - let run_at = Utc::now().add(wait); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(run_at) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn update_by_id( - &self, - job_id: &JobId, - job: &JobRequest, - ) -> StorageResult<()> { - let pool = self.pool.clone(); - let status = job.status().as_ref(); - let attempts = job.attempts(); - let done_at = *job.done_at(); - let lock_by = job.lock_by().clone(); - let lock_at = *job.lock_at(); - let last_error = job.last_error().clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7"; - sqlx::query(query) - .bind(status.to_owned()) - .bind(attempts) - .bind(done_at) - .bind(lock_by.as_ref().map(WorkerId::name)) - .bind(lock_at) - .bind(last_error) - .bind(job_id.to_string()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn keep_alive(&mut self, worker_id: &WorkerId) -> StorageResult<()> { - #[allow(clippy::disallowed_methods)] - let now = Utc::now(); - - self.keep_alive_at::(worker_id, now).await - } -} diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index b3d062bb4..ad4444be5 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -5,10 +5,9 @@ // Please see LICENSE in the repository root for full details. use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use mas_storage::{ compat::CompatSessionFilter, - job::JobWithSpanContext, oauth2::OAuth2SessionFilter, queue::{DeactivateUserJob, ReactivateUserJob}, user::{BrowserSessionFilter, UserRepository}, @@ -16,122 +15,106 @@ use mas_storage::{ }; use tracing::info; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; /// Job to deactivate a user, both locally and on the Matrix homeserver. -#[tracing::instrument( +#[async_trait] +impl RunnableJob for DeactivateUserJob { + #[tracing::instrument( name = "job.deactivate_user" - fields(user.id = %job.user_id(), erase = %job.hs_erase()), - skip_all, - err(Debug), -)] -async fn deactivate_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let clock = state.clock(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? 
- .context("User not found")?; - - // Let's first lock the user - let user = repo - .user() - .lock(&clock, user) - .await - .context("Failed to lock user")?; - - // Kill all sessions for the user - let n = repo - .browser_session() - .finish_bulk( - &clock, - BrowserSessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all browser sessions for user"); - - let n = repo - .oauth2_session() - .finish_bulk( - &clock, - OAuth2SessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all OAuth 2.0 sessions for user"); - - let n = repo - .compat_session() - .finish_bulk( - &clock, - CompatSessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all compatibility sessions for user"); - - // Before calling back to the homeserver, commit the changes to the database, as - // we want the user to be locked out as soon as possible - repo.save().await?; - - let mxid = matrix.mxid(&user.username); - info!("Deactivating user {} on homeserver", mxid); - matrix.delete_user(&mxid, job.hs_erase()).await?; - - Ok(()) + fields(user.id = %self.user_id(), erase = %self.hs_erase()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let clock = state.clock(); + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Let's first lock the user + let user = repo + .user() + .lock(&clock, user) + .await + .context("Failed to lock user")?; + + // Kill all sessions for the user + let n = repo + .browser_session() + .finish_bulk( + &clock, + BrowserSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all browser sessions for user"); + + let n = repo + .oauth2_session() + .finish_bulk( + &clock, + OAuth2SessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all OAuth 2.0 sessions for user"); + + let n = repo + .compat_session() + .finish_bulk( + &clock, + CompatSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all compatibility sessions for user"); + + // Before calling back to the homeserver, commit the changes to the database, as + // we want the user to be locked out as soon as possible + repo.save().await?; + + let mxid = matrix.mxid(&user.username); + info!("Deactivating user {} on homeserver", mxid); + matrix.delete_user(&mxid, self.hs_erase()).await?; + + Ok(()) + } } /// Job to reactivate a user, both locally and on the Matrix homeserver. -#[tracing::instrument( - name = "job.reactivate_user", - fields(user.id = %job.user_id()), - skip_all, - err(Debug), -)] -pub async fn reactivate_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? 
- .context("User not found")?; - - let mxid = matrix.mxid(&user.username); - info!("Reactivating user {} on homeserver", mxid); - matrix.reactivate_user(&mxid).await?; - - // We want to unlock the user from our side only once it has been reactivated on - // the homeserver - let _user = repo.user().unlock(user).await?; - repo.save().await?; - - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let deactivate_user_worker = - crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory); - - let reactivate_user_worker = - crate::build!(ReactivateUserJob => reactivate_user, suffix, state, storage_factory); - - monitor - .register(deactivate_user_worker) - .register(reactivate_user_worker) +#[async_trait] +impl RunnableJob for ReactivateUserJob { + #[tracing::instrument( + name = "job.reactivate_user", + fields(user.id = %self.user_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + let mxid = matrix.mxid(&user.username); + info!("Reactivating user {} on homeserver", mxid); + matrix.reactivate_user(&mxid).await?; + + // We want to unlock the user from our side only once it has been reactivated on + // the homeserver + let _user = repo.user().unlock(user).await?; + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/utils.rs b/crates/tasks/src/utils.rs deleted file mode 100644 index c5862f9cf..000000000 --- a/crates/tasks/src/utils.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use apalis_core::{job::Job, request::JobRequest}; -use mas_storage::job::JobWithSpanContext; -use mas_tower::{ - make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer, - TraceLayer, KV, -}; -use opentelemetry::{trace::SpanContext, Key, KeyValue}; -use tracing::info_span; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -const JOB_NAME: Key = Key::from_static_str("job.name"); -const JOB_STATUS: Key = Key::from_static_str("job.status"); - -/// Represents a job that can may have a span context attached to it. -pub trait TracedJob: Job { - /// Returns the span context for this job, if any. - /// - /// The default implementation returns `None`. - fn span_context(&self) -> Option { - None - } -} - -/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`] -/// wrapper. 
-impl TracedJob for JobWithSpanContext { - fn span_context(&self) -> Option { - JobWithSpanContext::span_context(self) - } -} - -fn make_span_for_job_request(req: &JobRequest) -> tracing::Span { - let span = info_span!( - "job.run", - "otel.kind" = "consumer", - "otel.status_code" = tracing::field::Empty, - "job.id" = %req.id(), - "job.attempts" = req.attempts(), - "job.name" = J::NAME, - ); - - if let Some(context) = req.inner().span_context() { - span.add_link(context); - } - - span -} - -type TraceLayerForJob = - TraceLayer) -> tracing::Span>, KV<&'static str>, KV<&'static str>>; - -pub(crate) fn trace_layer() -> TraceLayerForJob -where - J: TracedJob, -{ - TraceLayer::new(make_span_fn( - make_span_for_job_request:: as fn(&JobRequest) -> tracing::Span, - )) - .on_response(KV("otel.status_code", "OK")) - .on_error(KV("otel.status_code", "ERROR")) -} - -type MetricsLayerForJob = ( - IdentityLayer>, - DurationRecorderLayer, - InFlightCounterLayer, -); - -pub(crate) fn metrics_layer() -> MetricsLayerForJob -where - J: Job, -{ - let duration_recorder = DurationRecorderLayer::new("job.run.duration") - .on_request(JOB_NAME.string(J::NAME)) - .on_response(JOB_STATUS.string("success")) - .on_error(JOB_STATUS.string("error")); - let in_flight_counter = - InFlightCounterLayer::new("job.run.active").on_request(JOB_NAME.string(J::NAME)); - - ( - IdentityLayer::default(), - duration_recorder, - in_flight_counter, - ) -} From ca44be7b54dcd8cf949ed56eb0352c2aee293544 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 20 Nov 2024 17:03:00 +0100 Subject: [PATCH 12/17] Decide in each job whether it should retry or not --- crates/tasks/src/email.rs | 32 +++++++----- crates/tasks/src/matrix.rs | 96 ++++++++++++++++++++--------------- crates/tasks/src/new_queue.rs | 76 +++++++++++++++++++++++---- crates/tasks/src/recovery.rs | 39 ++++++++------ crates/tasks/src/user.rs | 54 ++++++++++++-------- 5 files changed, 198 insertions(+), 99 deletions(-) diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index 3afbab8ce..25cbf2e7d 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -15,7 +15,7 @@ use rand::{distributions::Uniform, Rng}; use tracing::info; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -25,10 +25,10 @@ impl RunnableJob for VerifyEmailJob { name = "job.verify_email", fields(user_email.id = %self.user_email_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { - let mut repo = state.repository().await?; + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { + let mut repo = state.repository().await.map_err(JobError::retry)?; let mut rng = state.rng(); let mailer = state.mailer(); let clock = state.clock(); @@ -42,22 +42,26 @@ impl RunnableJob for VerifyEmailJob { let user_email = repo .user_email() .lookup(self.user_email_id()) - .await? - .context("User email not found")?; + .await + .map_err(JobError::retry)? + .context("User email not found") + .map_err(JobError::fail)?; // Lookup the user associated with the email let user = repo .user() .lookup(user_email.user_id) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? 
+ .context("User not found") + .map_err(JobError::fail)?; // Generate a verification code let range = Uniform::::from(0..1_000_000); let code = rng.sample(range); let code = format!("{code:06}"); - let address: Address = user_email.email.parse()?; + let address: Address = user_email.email.parse().map_err(JobError::fail)?; // Save the verification code in the database let verification = repo @@ -69,7 +73,8 @@ impl RunnableJob for VerifyEmailJob { Duration::try_hours(8).unwrap(), code, ) - .await?; + .await + .map_err(JobError::retry)?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -77,14 +82,17 @@ impl RunnableJob for VerifyEmailJob { let context = EmailVerificationContext::new(user.clone(), verification.clone()) .with_language(language); - mailer.send_verification_email(mailbox, &context).await?; + mailer + .send_verification_email(mailbox, &context) + .await + .map_err(JobError::retry)?; info!( email.id = %user_email.id, "Verification email sent" ); - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index f4596c05f..0f58773b3 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -23,7 +23,7 @@ use mas_storage::{ use tracing::info; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -36,25 +36,28 @@ impl RunnableJob for ProvisionUserJob { name = "job.provision_user" fields(user.id = %self.user_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let mut rng = state.rng(); let clock = state.clock(); let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; let mxid = matrix.mxid(&user.username); let emails = repo .user_email() .all(&user) - .await? + .await + .map_err(JobError::retry)? .into_iter() .filter(|email| email.confirmed_at.is_some()) .map(|email| email.email) @@ -65,7 +68,10 @@ impl RunnableJob for ProvisionUserJob { request = request.set_displayname(display_name.to_owned()); } - let created = matrix.provision_user(&request).await?; + let created = matrix + .provision_user(&request) + .await + .map_err(JobError::retry)?; if created { info!(%user.id, %mxid, "User created"); @@ -77,9 +83,10 @@ impl RunnableJob for ProvisionUserJob { let sync_device_job = SyncDevicesJob::new(&user); repo.queue_job() .schedule_job(&mut rng, &clock, sync_device_job) - .await?; + .await + .map_err(JobError::retry)?; - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } @@ -97,23 +104,26 @@ impl RunnableJob for ProvisionDeviceJob { device.id = %self.device_id(), ), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { - let mut repo = state.repository().await?; + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { + let mut repo = state.repository().await.map_err(JobError::retry)?; let mut rng = state.rng(); let clock = state.clock(); let user = repo .user() .lookup(self.user_id()) - .await? 
- .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; // Schedule a device sync job repo.queue_job() .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) - .await?; + .await + .map_err(JobError::retry)?; Ok(()) } @@ -131,32 +141,26 @@ impl RunnableJob for DeleteDeviceJob { device.id = %self.device_id(), ), skip_all, - err(Debug), + err, )] - #[tracing::instrument( - name = "job.delete_device" - fields( - user.id = %self.user_id(), - device.id = %self.device_id(), - ), - skip_all, - err(Debug), - )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let mut rng = state.rng(); let clock = state.clock(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; // Schedule a device sync job repo.queue_job() .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) - .await?; + .await + .map_err(JobError::retry)?; Ok(()) } @@ -169,20 +173,25 @@ impl RunnableJob for SyncDevicesJob { name = "job.sync_devices", fields(user.id = %self.user_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? 
+ .context("User not found") + .map_err(JobError::fail)?; // Lock the user sync to make sure we don't get into a race condition - repo.user().acquire_lock_for_sync(&user).await?; + repo.user() + .acquire_lock_for_sync(&user) + .await + .map_err(JobError::retry)?; let mut devices = HashSet::new(); @@ -195,7 +204,8 @@ impl RunnableJob for SyncDevicesJob { CompatSessionFilter::new().for_user(&user).active_only(), cursor, ) - .await?; + .await + .map_err(JobError::retry)?; for (compat_session, _) in page.edges { devices.insert(compat_session.device.as_str().to_owned()); @@ -216,7 +226,8 @@ impl RunnableJob for SyncDevicesJob { OAuth2SessionFilter::new().for_user(&user).active_only(), cursor, ) - .await?; + .await + .map_err(JobError::retry)?; for oauth2_session in page.edges { for scope in &*oauth2_session.scope { @@ -234,11 +245,14 @@ impl RunnableJob for SyncDevicesJob { } let mxid = matrix.mxid(&user.username); - matrix.sync_devices(&mxid, devices).await?; + matrix + .sync_devices(&mxid, devices) + .await + .map_err(JobError::retry)?; // We kept the connection until now, so that we still hold the lock on the user // throughout the sync - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index 42a037af4..ba707cff8 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -54,6 +54,47 @@ impl JobContext { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum JobErrorDecision { + Retry, + + #[default] + Fail, +} + +impl std::fmt::Display for JobErrorDecision { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Retry => f.write_str("retry"), + Self::Fail => f.write_str("fail"), + } + } +} + +#[derive(Debug, Error)] +#[error("Job failed to run, will {decision}")] +pub struct JobError { + decision: JobErrorDecision, + #[source] + error: anyhow::Error, +} + +impl JobError { + pub fn retry>(error: T) -> Self { + Self { + decision: JobErrorDecision::Retry, + error: error.into(), + } + } + + pub fn fail>(error: T) -> Self { + Self { + decision: JobErrorDecision::Fail, + error: error.into(), + } + } +} + pub trait FromJob { fn from_job(payload: JobPayload) -> Result where @@ -71,7 +112,7 @@ where #[async_trait] pub trait RunnableJob: FromJob + Send + 'static { - async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>; + async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>; } fn box_runnable_job(job: T) -> Box { @@ -126,13 +167,13 @@ pub struct QueueWorker { last_heartbeat: DateTime, cancellation_token: CancellationToken, state: State, - running_jobs: JoinSet>, + running_jobs: JoinSet>, job_contexts: HashMap, factories: HashMap<&'static str, JobFactory>, #[allow(clippy::type_complexity)] last_join_result: - Option), tokio::task::JoinError>>, + Option), tokio::task::JoinError>>, } impl QueueWorker { @@ -379,14 +420,27 @@ impl QueueWorker { .remove(&id) .expect("Job context not found"); - tracing::error!( - error = ?e, - job.id = %context.id, - job.queue_name = %context.queue_name, - "Job failed" - ); - - // TODO: reschedule the job + match e.decision { + JobErrorDecision::Fail => { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + "Job failed" + ); + } + + JobErrorDecision::Retry => { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = 
%context.queue_name, + "Job failed, will retry" + ); + + // TODO: reschedule the job + } + } context } diff --git a/crates/tasks/src/recovery.rs b/crates/tasks/src/recovery.rs index cd3787d2a..294d7f1ba 100644 --- a/crates/tasks/src/recovery.rs +++ b/crates/tasks/src/recovery.rs @@ -18,7 +18,7 @@ use rand::distributions::{Alphanumeric, DistString}; use tracing::{error, info}; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -32,20 +32,22 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { user_recovery_session.email, ), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let clock = state.clock(); let mailer = state.mailer(); let url_builder = state.url_builder(); let mut rng = state.rng(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let session = repo .user_recovery() .lookup_session(self.user_recovery_session_id()) - .await? - .context("User recovery session not found")?; + .await + .map_err(JobError::retry)? + .context("User recovery session not found") + .map_err(JobError::fail)?; tracing::Span::current().record("user_recovery_session.email", &session.email); @@ -59,7 +61,8 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { let lang: DataLocale = session .locale .parse() - .context("Invalid locale in database on recovery session")?; + .context("Invalid locale in database on recovery session") + .map_err(JobError::fail)?; loop { let page = repo @@ -70,7 +73,8 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { .verified_only(), cursor, ) - .await?; + .await + .map_err(JobError::retry)?; for email in page.edges { let ticket = Alphanumeric.sample_string(&mut rng, 32); @@ -78,23 +82,28 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { let ticket = repo .user_recovery() .add_ticket(&mut rng, &clock, &session, &email, ticket) - .await?; + .await + .map_err(JobError::retry)?; let user_email = repo .user_email() .lookup(email.id) - .await? - .context("User email not found")?; + .await + .map_err(JobError::retry)? + .context("User email not found") + .map_err(JobError::fail)?; let user = repo .user() .lookup(user_email.user_id) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? 
+ .context("User not found") + .map_err(JobError::fail)?; let url = url_builder.account_recovery_link(ticket.ticket); - let address: Address = user_email.email.parse()?; + let address: Address = user_email.email.parse().map_err(JobError::fail)?; let mailbox = Mailbox::new(Some(user.username.clone()), address); info!("Sending recovery email to {}", mailbox); @@ -117,7 +126,7 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { } } - repo.save().await?; + repo.save().await.map_err(JobError::fail)?; Ok(()) } diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index ad4444be5..eaa9d2b43 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -16,7 +16,7 @@ use mas_storage::{ use tracing::info; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -27,25 +27,28 @@ impl RunnableJob for DeactivateUserJob { name = "job.deactivate_user" fields(user.id = %self.user_id(), erase = %self.hs_erase()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let clock = state.clock(); let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; // Let's first lock the user let user = repo .user() .lock(&clock, user) .await - .context("Failed to lock user")?; + .context("Failed to lock user") + .map_err(JobError::retry)?; // Kill all sessions for the user let n = repo @@ -54,7 +57,8 @@ impl RunnableJob for DeactivateUserJob { &clock, BrowserSessionFilter::new().for_user(&user).active_only(), ) - .await?; + .await + .map_err(JobError::retry)?; info!(affected = n, "Killed all browser sessions for user"); let n = repo @@ -63,7 +67,8 @@ impl RunnableJob for DeactivateUserJob { &clock, OAuth2SessionFilter::new().for_user(&user).active_only(), ) - .await?; + .await + .map_err(JobError::retry)?; info!(affected = n, "Killed all OAuth 2.0 sessions for user"); let n = repo @@ -72,16 +77,20 @@ impl RunnableJob for DeactivateUserJob { &clock, CompatSessionFilter::new().for_user(&user).active_only(), ) - .await?; + .await + .map_err(JobError::retry)?; info!(affected = n, "Killed all compatibility sessions for user"); // Before calling back to the homeserver, commit the changes to the database, as // we want the user to be locked out as soon as possible - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; let mxid = matrix.mxid(&user.username); info!("Deactivating user {} on homeserver", mxid); - matrix.delete_user(&mxid, self.hs_erase()).await?; + matrix + .delete_user(&mxid, self.hs_erase()) + .await + .map_err(JobError::retry)?; Ok(()) } @@ -94,26 +103,31 @@ impl RunnableJob for ReactivateUserJob { name = "job.reactivate_user", fields(user.id = %self.user_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? 
- .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; let mxid = matrix.mxid(&user.username); info!("Reactivating user {} on homeserver", mxid); - matrix.reactivate_user(&mxid).await?; + matrix + .reactivate_user(&mxid) + .await + .map_err(JobError::retry)?; // We want to unlock the user from our side only once it has been reactivated on // the homeserver - let _user = repo.user().unlock(user).await?; - repo.save().await?; + let _user = repo.user().unlock(user).await.map_err(JobError::retry)?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } From 390a65c66f68644b308851e39af3c38c6ae78cfd Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 20 Nov 2024 18:36:45 +0100 Subject: [PATCH 13/17] Retry failed jobs --- ...6c35c9c236ea8beb6696e5740fa45655e59f3.json | 15 +++ ...a8e4d1682263079ec09c38a20c059580adb38.json | 16 +++ ...1388d6723f82549d88d704d9c939b9d35c49.json} | 10 +- ...ca42c790c101a3fc9442862b5885d5116325a.json | 16 +++ .../20241120163320_queue_job_failures.sql | 17 +++ crates/storage-pg/src/queue/job.rs | 111 +++++++++++++++++- crates/storage/src/queue/job.rs | 54 ++++++++- crates/tasks/src/new_queue.rs | 94 +++++++++++---- 8 files changed, 303 insertions(+), 30 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json create mode 100644 crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json rename crates/storage-pg/.sqlx/{query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json => query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json} (87%) create mode 100644 crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json create mode 100644 crates/storage-pg/migrations/20241120163320_queue_job_failures.sql diff --git a/crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json b/crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json new file mode 100644 index 000000000..e5ffe95e2 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET next_attempt_id = $1\n WHERE queue_job_id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3" +} diff --git a/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json b/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json new file mode 100644 index 000000000..2962db553 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, attempt)\n SELECT $1, queue_name, payload, metadata, $2, attempt + 1\n FROM queue_jobs\n WHERE queue_job_id = $3\n AND status = 'failed'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38" +} diff --git 
a/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json b/crates/storage-pg/.sqlx/query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json similarity index 87% rename from crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json rename to crates/storage-pg/.sqlx/query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json index 67f1ad132..88eb81f9f 100644 --- a/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json +++ b/crates/storage-pg/.sqlx/query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata\n ", + "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata,\n queue_jobs.attempt\n ", "describe": { "columns": [ { @@ -22,6 +22,11 @@ "ordinal": 3, "name": "metadata", "type_info": "Jsonb" + }, + { + "ordinal": 4, + "name": "attempt", + "type_info": "Int4" } ], "parameters": { @@ -36,8 +41,9 @@ false, false, false, + false, false ] }, - "hash": "9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061" + "hash": "707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49" } diff --git a/crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json b/crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json new file mode 100644 index 000000000..df75b11b1 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET\n status = 'failed',\n failed_at = $1,\n failed_reason = $2\n WHERE\n queue_job_id = $3\n AND status = 'running'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Text", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a" +} diff --git a/crates/storage-pg/migrations/20241120163320_queue_job_failures.sql 
b/crates/storage-pg/migrations/20241120163320_queue_job_failures.sql new file mode 100644 index 000000000..0407d6342 --- /dev/null +++ b/crates/storage-pg/migrations/20241120163320_queue_job_failures.sql @@ -0,0 +1,17 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a new status for failed jobs +ALTER TYPE "queue_job_status" ADD VALUE 'failed'; + +ALTER TABLE "queue_jobs" + -- When the job failed + ADD COLUMN "failed_at" TIMESTAMP WITH TIME ZONE, + -- Error message of the failure + ADD COLUMN "failed_reason" TEXT, + -- How many times we've already tried to run the job + ADD COLUMN "attempt" INTEGER NOT NULL DEFAULT 0, + -- The next attempt, if it was retried + ADD COLUMN "next_attempt_id" UUID REFERENCES "queue_jobs" ("queue_job_id"); diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index 02ceed793..769a9eb49 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -37,6 +37,7 @@ struct JobReservationResult { queue_name: String, payload: serde_json::Value, metadata: serde_json::Value, + attempt: i32, } impl TryFrom for Job { @@ -54,11 +55,19 @@ impl TryFrom for Job { .source(e) })?; + let attempt = value.attempt.try_into().map_err(|e| { + DatabaseInconsistencyError::on("queue_jobs") + .column("attempt") + .row(id) + .source(e) + })?; + Ok(Self { id, queue_name, payload, metadata, + attempt, }) } } @@ -152,7 +161,8 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { queue_jobs.queue_job_id, queue_jobs.queue_name, queue_jobs.payload, - queue_jobs.metadata + queue_jobs.metadata, + queue_jobs.attempt "#, &queues, max_count, @@ -199,4 +209,103 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { Ok(()) } + + #[tracing::instrument( + name = "db.queue_job.mark_as_failed", + skip_all, + fields( + db.query.text, + job.id = %id, + ), + err + )] + async fn mark_as_failed( + &mut self, + clock: &dyn Clock, + id: Ulid, + reason: &str, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let res = sqlx::query!( + r#" + UPDATE queue_jobs + SET + status = 'failed', + failed_at = $1, + failed_reason = $2 + WHERE + queue_job_id = $3 + AND status = 'running' + "#, + now, + reason, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.queue_job.retry", + skip_all, + fields( + db.query.text, + job.id = %id, + ), + err + )] + async fn retry( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + id: Ulid, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let new_id = Ulid::from_datetime_with_source(now.into(), rng); + + // Create a new job with the same payload and metadata, but a new ID and + // increment the attempt + // We make sure we do this only for 'failed' jobs + let res = sqlx::query!( + r#" + INSERT INTO queue_jobs + (queue_job_id, queue_name, payload, metadata, created_at, attempt) + SELECT $1, queue_name, payload, metadata, $2, attempt + 1 + FROM queue_jobs + WHERE queue_job_id = $3 + AND status = 'failed' + "#, + Uuid::from(new_id), + now, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + // Update the old job to point to the new attempt + let res = sqlx::query!( + r#" + UPDATE queue_jobs + SET next_attempt_id = $1 + WHERE queue_job_id = $2 + "#, + Uuid::from(new_id), + 
Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } } diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs index 13df586d7..9a24fa649 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -28,6 +28,9 @@ pub struct Job { /// Arbitrary metadata about the job pub metadata: JobMetadata, + + /// Which attempt it is + pub attempt: usize, } /// Metadata stored alongside the job @@ -127,12 +130,48 @@ pub trait QueueJobRepository: Send + Sync { /// # Parameters /// /// * `clock` - The clock used to generate timestamps - /// * `job` - The job to mark as completed + /// * `id` - The ID of the job to mark as completed /// /// # Errors /// /// Returns an error if the underlying repository fails. async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>; + + /// Marks a job as failed. + /// + /// # Parameters + /// + /// * `clock` - The clock used to generate timestamps + /// * `id` - The ID of the job to mark as failed + /// * `reason` - The reason for the failure + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn mark_as_failed( + &mut self, + clock: &dyn Clock, + id: Ulid, + reason: &str, + ) -> Result<(), Self::Error>; + + /// Retry a job. + /// + /// # Parameters + /// + /// * `rng` - The random number generator used to generate a new job ID + /// * `clock` - The clock used to generate timestamps + /// * `id` - The ID of the job to reschedule + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn retry( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + id: Ulid, + ) -> Result<(), Self::Error>; } repository_impl!(QueueJobRepository: @@ -154,6 +193,19 @@ repository_impl!(QueueJobRepository: ) -> Result, Self::Error>; async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>; + + async fn mark_as_failed(&mut self, + clock: &dyn Clock, + id: Ulid, + reason: &str, + ) -> Result<(), Self::Error>; + + async fn retry( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + id: Ulid, + ) -> Result<(), Self::Error>; ); /// Extension trait for [`QueueJobRepository`] to help adding a job to the queue diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index ba707cff8..143b83ece 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -35,6 +35,7 @@ pub struct JobContext { pub id: Ulid, pub metadata: JobMetadata, pub queue_name: String, + pub attempt: usize, pub cancellation_token: CancellationToken, } @@ -156,6 +157,9 @@ const MAX_CONCURRENT_JOBS: usize = 10; // How many jobs can we fetch at once const MAX_JOBS_TO_FETCH: usize = 5; +// How many attempts a job should be retried +const MAX_ATTEMPTS: usize = 5; + type JobFactory = Arc Box + Send + Sync>; pub struct QueueWorker { @@ -280,6 +284,8 @@ impl QueueWorker { async fn shutdown(&mut self) -> Result<(), QueueRunnerError> { tracing::info!("Shutting down worker"); + // TODO: collect running jobs + // Start a transaction on the existing PgListener connection let txn = self .listener @@ -397,7 +403,7 @@ impl QueueWorker { while let Some(result) = self.last_join_result.take() { // TODO: add metrics to track the job status and the time it took - let context = match result { + match result { Ok((id, Ok(()))) => { // The job succeeded let context = self @@ -408,10 +414,13 @@ impl 
QueueWorker { tracing::info!( job.id = %context.id, job.queue_name = %context.queue_name, + job.attempt = %context.attempt, "Job completed" ); - context + repo.queue_job() + .mark_as_completed(&self.clock, context.id) + .await?; } Ok((id, Err(e))) => { // The job failed @@ -420,29 +429,48 @@ impl QueueWorker { .remove(&id) .expect("Job context not found"); + let reason = format!("{:?}", e.error); + repo.queue_job() + .mark_as_failed(&self.clock, context.id, &reason) + .await?; + match e.decision { JobErrorDecision::Fail => { tracing::error!( error = &e as &dyn std::error::Error, job.id = %context.id, job.queue_name = %context.queue_name, - "Job failed" + job.attempt = %context.attempt, + "Job failed, not retrying" ); } JobErrorDecision::Retry => { - tracing::warn!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - "Job failed, will retry" - ); - - // TODO: reschedule the job + if context.attempt < MAX_ATTEMPTS { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed, will retry" + ); + + // TODO: retry with an exponential backoff, once we know how to + // schedule jobs in the future + repo.queue_job() + .retry(&mut self.rng, &self.clock, context.id) + .await?; + } else { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed too many times, abandonning" + ); + } } } - - context } Err(e) => { // The job crashed (or was cancelled) @@ -452,23 +480,35 @@ impl QueueWorker { .remove(&id) .expect("Job context not found"); - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - "Job crashed" - ); - - // TODO: reschedule the job + let reason = e.to_string(); + repo.queue_job() + .mark_as_failed(&self.clock, context.id, &reason) + .await?; + + if context.attempt < MAX_ATTEMPTS { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job crashed, will retry" + ); - context + repo.queue_job() + .retry(&mut self.rng, &self.clock, context.id) + .await?; + } else { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job crashed too many times, abandonning" + ); + } } }; - repo.queue_job() - .mark_as_completed(&self.clock, context.id) - .await?; - self.last_join_result = self.running_jobs.try_join_next_with_id(); } @@ -492,6 +532,7 @@ impl QueueWorker { queue_name, payload, metadata, + attempt, } in jobs { let cancellation_token = self.cancellation_token.child_token(); @@ -500,6 +541,7 @@ impl QueueWorker { id, metadata, queue_name, + attempt, cancellation_token, }; From 6d92fb1c2aa8e00175afbbaa4bb418888fac5dd4 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 21 Nov 2024 10:55:45 +0100 Subject: [PATCH 14/17] Refactor job processing to wait for them to finish on shutdown --- crates/tasks/src/new_queue.rs | 395 +++++++++++++++++++++------------- 1 file changed, 241 insertions(+), 154 deletions(-) diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index 143b83ece..ce8504116 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -12,7 +12,7 @@ use mas_storage::{ Clock, RepositoryAccess, 
RepositoryError, }; use mas_storage_pg::{DatabaseError, PgRepository}; -use rand::{distributions::Uniform, Rng}; +use rand::{distributions::Uniform, Rng, RngCore}; use rand_chacha::ChaChaRng; use serde::de::DeserializeOwned; use sqlx::{ @@ -160,6 +160,7 @@ const MAX_JOBS_TO_FETCH: usize = 5; // How many attempts a job should be retried const MAX_ATTEMPTS: usize = 5; +type JobResult = Result<(), JobError>; type JobFactory = Arc Box + Send + Sync>; pub struct QueueWorker { @@ -171,13 +172,7 @@ pub struct QueueWorker { last_heartbeat: DateTime, cancellation_token: CancellationToken, state: State, - running_jobs: JoinSet>, - job_contexts: HashMap, - factories: HashMap<&'static str, JobFactory>, - - #[allow(clippy::type_complexity)] - last_join_result: - Option), tokio::task::JoinError>>, + tracker: JobTracker, } impl QueueWorker { @@ -234,10 +229,7 @@ impl QueueWorker { last_heartbeat: now, cancellation_token, state, - job_contexts: HashMap::new(), - running_jobs: JoinSet::new(), - factories: HashMap::new(), - last_join_result: None, + tracker: JobTracker::default(), }) } @@ -248,7 +240,9 @@ impl QueueWorker { box_runnable_job(T::from_job(payload).expect("Failed to deserialize job")) }; - self.factories.insert(T::QUEUE_NAME, Arc::new(factory)); + self.tracker + .factories + .insert(T::QUEUE_NAME, Arc::new(factory)); self } @@ -266,7 +260,6 @@ impl QueueWorker { async fn run_loop(&mut self) -> Result<(), QueueRunnerError> { self.wait_until_wakeup().await?; - // TODO: join all the jobs handles when shutting down if self.cancellation_token.is_cancelled() { return Ok(()); } @@ -284,8 +277,6 @@ impl QueueWorker { async fn shutdown(&mut self) -> Result<(), QueueRunnerError> { tracing::info!("Shutting down worker"); - // TODO: collect running jobs - // Start a transaction on the existing PgListener connection let txn = self .listener @@ -295,6 +286,24 @@ impl QueueWorker { let mut repo = PgRepository::from_conn(txn); + // Log about any job still running + match self.tracker.running_jobs() { + 0 => {} + 1 => tracing::warn!("There is one job still running, waiting for it to finish"), + n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"), + } + + // TODO: we may want to introduce a timeout here, and abort the tasks if they + // take too long. 
It's fine for now, as we don't have long-running + // tasks, most of them are idempotent, and the only effect might be that + // the worker would 'dirtily' shutdown, meaning that its tasks would be + // considered, later retried by another worker + + // Wait for all the jobs to finish + self.tracker + .process_jobs(&mut self.rng, &self.clock, &mut repo, true) + .await?; + // Tell the other workers we're shutting down // This also releases the leader election lease repo.queue_worker() @@ -330,9 +339,8 @@ impl QueueWorker { tracing::debug!("Woke up from sleep"); }, - Some(result) = self.running_jobs.join_next_with_id() => { + () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => { tracing::debug!("Joined job task"); - self.last_join_result = Some(result); }, notification = self.listener.recv() => { @@ -393,135 +401,21 @@ impl QueueWorker { .try_get_leader_lease(&self.clock, &self.registration) .await?; - // Find any job task which finished - // If we got woken up by a join on the joinset, it will be stored in the - // last_join_result so that we don't loose it - - if self.last_join_result.is_none() { - self.last_join_result = self.running_jobs.try_join_next_with_id(); - } - - while let Some(result) = self.last_join_result.take() { - // TODO: add metrics to track the job status and the time it took - match result { - Ok((id, Ok(()))) => { - // The job succeeded - let context = self - .job_contexts - .remove(&id) - .expect("Job context not found"); - - tracing::info!( - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job completed" - ); - - repo.queue_job() - .mark_as_completed(&self.clock, context.id) - .await?; - } - Ok((id, Err(e))) => { - // The job failed - let context = self - .job_contexts - .remove(&id) - .expect("Job context not found"); - - let reason = format!("{:?}", e.error); - repo.queue_job() - .mark_as_failed(&self.clock, context.id, &reason) - .await?; - - match e.decision { - JobErrorDecision::Fail => { - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job failed, not retrying" - ); - } - - JobErrorDecision::Retry => { - if context.attempt < MAX_ATTEMPTS { - tracing::warn!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job failed, will retry" - ); - - // TODO: retry with an exponential backoff, once we know how to - // schedule jobs in the future - repo.queue_job() - .retry(&mut self.rng, &self.clock, context.id) - .await?; - } else { - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job failed too many times, abandonning" - ); - } - } - } - } - Err(e) => { - // The job crashed (or was cancelled) - let id = e.id(); - let context = self - .job_contexts - .remove(&id) - .expect("Job context not found"); - - let reason = e.to_string(); - repo.queue_job() - .mark_as_failed(&self.clock, context.id, &reason) - .await?; - - if context.attempt < MAX_ATTEMPTS { - tracing::warn!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job crashed, will retry" - ); - - repo.queue_job() - .retry(&mut self.rng, &self.clock, context.id) - .await?; - } else { - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = 
%context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job crashed too many times, abandonning" - ); - } - } - }; - - self.last_join_result = self.running_jobs.try_join_next_with_id(); - } + // Process any job task which finished + self.tracker + .process_jobs(&mut self.rng, &self.clock, &mut repo, false) + .await?; // Compute how many jobs we should fetch at most let max_jobs_to_fetch = MAX_CONCURRENT_JOBS - .saturating_sub(self.running_jobs.len()) + .saturating_sub(self.tracker.running_jobs()) .max(MAX_JOBS_TO_FETCH); if max_jobs_to_fetch == 0 { tracing::warn!("Internal job queue is full, not fetching any new jobs"); } else { // Grab a few jobs in the queue - let queues = self.factories.keys().copied().collect::>(); + let queues = self.tracker.queues(); let jobs = repo .queue_job() .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch) @@ -536,7 +430,6 @@ impl QueueWorker { } in jobs { let cancellation_token = self.cancellation_token.child_token(); - let factory = self.factories.get(queue_name.as_str()).cloned(); let context = JobContext { id, metadata, @@ -545,21 +438,7 @@ impl QueueWorker { cancellation_token, }; - let task = { - let context = context.clone(); - let span = context.span(); - let state = self.state.clone(); - async move { - // We should never crash, but in case we do, we do that in the task and - // don't crash the worker - let job = factory.expect("unknown job factory")(payload); - job.run(&state, context).await - } - .instrument(span) - }; - - let handle = self.running_jobs.spawn(task); - self.job_contexts.insert(handle.id(), context); + self.tracker.spawn_job(self.state.clone(), context, payload); } } @@ -651,3 +530,211 @@ impl QueueWorker { Ok(()) } } + +/// Tracks running jobs +/// +/// This is a separate structure to be able to borrow it mutably at the same +/// time as the connection to the database is borrowed +#[derive(Default)] +struct JobTracker { + /// Stores a mapping from the job queue name to the job factory + factories: HashMap<&'static str, JobFactory>, + + /// A join set of all the currently running jobs + running_jobs: JoinSet, + + /// Stores a mapping from the Tokio task ID to the job context + job_contexts: HashMap, + + /// Stores the last `join_next_with_id` result for processing, in case we + /// got woken up in `collect_next_job` + last_join_result: Option>, +} + +impl JobTracker { + /// Returns the queue names that are currently being tracked + fn queues(&self) -> Vec<&'static str> { + self.factories.keys().copied().collect() + } + + /// Spawn a job on the job tracker + fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) { + let factory = self.factories.get(context.queue_name.as_str()).cloned(); + let task = { + let context = context.clone(); + let span = context.span(); + async move { + // We should never crash, but in case we do, we do that in the task and + // don't crash the worker + let job = factory.expect("unknown job factory")(payload); + tracing::info!("Running job"); + job.run(&state, context).await + } + .instrument(span) + }; + + let handle = self.running_jobs.spawn(task); + self.job_contexts.insert(handle.id(), context); + } + + /// Returns `true` if there are currently running jobs + fn has_jobs(&self) -> bool { + !self.running_jobs.is_empty() + } + + /// Returns the number of currently running jobs + /// + /// This also includes the job result which may be stored for processing + fn running_jobs(&self) -> usize { + self.running_jobs.len() + 
usize::from(self.last_join_result.is_some()) + } + + async fn collect_next_job(&mut self) { + // Double-check that we don't have a job result stored + if self.last_join_result.is_some() { + tracing::error!( + "Job tracker already had a job result stored, this should never happen!" + ); + return; + } + + self.last_join_result = self.running_jobs.join_next_with_id().await; + } + + /// Process all the jobs which are currently running + /// + /// If `blocking` is `true`, this function will block until all the jobs + /// are finished. Otherwise, it will return as soon as it processed the + /// already finished jobs. + #[allow(clippy::too_many_lines)] + async fn process_jobs( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + repo: &mut dyn RepositoryAccess, + blocking: bool, + ) -> Result<(), E> { + if self.last_join_result.is_none() { + if blocking { + self.last_join_result = self.running_jobs.join_next_with_id().await; + } else { + self.last_join_result = self.running_jobs.try_join_next_with_id(); + } + } + + while let Some(result) = self.last_join_result.take() { + // TODO: add metrics to track the job status and the time it took + match result { + // The job succeeded + Ok((id, Ok(()))) => { + let context = self + .job_contexts + .remove(&id) + .expect("Job context not found"); + + tracing::info!( + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job completed" + ); + + repo.queue_job() + .mark_as_completed(clock, context.id) + .await?; + } + + // The job failed + Ok((id, Err(e))) => { + let context = self + .job_contexts + .remove(&id) + .expect("Job context not found"); + + let reason = format!("{:?}", e.error); + repo.queue_job() + .mark_as_failed(clock, context.id, &reason) + .await?; + + match e.decision { + JobErrorDecision::Fail => { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed, not retrying" + ); + } + + JobErrorDecision::Retry => { + if context.attempt < MAX_ATTEMPTS { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed, will retry" + ); + + // TODO: retry with an exponential backoff, once we know how to + // schedule jobs in the future + repo.queue_job().retry(&mut *rng, clock, context.id).await?; + } else { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed too many times, abandonning" + ); + } + } + } + } + + // The job crashed (or was aborted) + Err(e) => { + let id = e.id(); + let context = self + .job_contexts + .remove(&id) + .expect("Job context not found"); + + let reason = e.to_string(); + repo.queue_job() + .mark_as_failed(clock, context.id, &reason) + .await?; + + if context.attempt < MAX_ATTEMPTS { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job crashed, will retry" + ); + + repo.queue_job().retry(&mut *rng, clock, context.id).await?; + } else { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job crashed too many times, abandonning" + ); + } + } + }; + + if blocking { + self.last_join_result = 
self.running_jobs.join_next_with_id().await; + } else { + self.last_join_result = self.running_jobs.try_join_next_with_id(); + } + } + + Ok(()) + } +} From 7b120f143751ac1efcedf1a146b5a7f994a637cf Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 22 Nov 2024 16:58:38 +0100 Subject: [PATCH 15/17] Allow scheduling jobs in the future Also retries jobs with an exponential backoff. --- ...ce891e626a82dcb78ff85f2b815d9329ff936.json | 17 ++++ ...b2e83858c9944893b8f3a0f0131e8a9b7a494.json | 14 +++ ...a8e4d1682263079ec09c38a20c059580adb38.json | 16 --- ...59e9fc0bf8a1fe9002dc3854ae28e65fc7943.json | 19 ++++ .../20241122130349_queue_job_scheduled.sql | 11 +++ ...241122133435_queue_job_scheduled_index.sql | 9 ++ crates/storage-pg/src/queue/job.rs | 82 +++++++++++++++- crates/storage/src/queue/job.rs | 97 +++++++++++++++++++ crates/tasks/src/new_queue.rs | 37 +++++-- 9 files changed, 277 insertions(+), 25 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json create mode 100644 crates/storage-pg/.sqlx/query-3c7960a2eb2edd71bc71177fc0fb2e83858c9944893b8f3a0f0131e8a9b7a494.json delete mode 100644 crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json create mode 100644 crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json create mode 100644 crates/storage-pg/migrations/20241122130349_queue_job_scheduled.sql create mode 100644 crates/storage-pg/migrations/20241122133435_queue_job_scheduled_index.sql diff --git a/crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json b/crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json new file mode 100644 index 000000000..c65354f92 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, attempt, scheduled_at, status)\n SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, 'scheduled'\n FROM queue_jobs\n WHERE queue_job_id = $4\n AND status = 'failed'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936" +} diff --git a/crates/storage-pg/.sqlx/query-3c7960a2eb2edd71bc71177fc0fb2e83858c9944893b8f3a0f0131e8a9b7a494.json b/crates/storage-pg/.sqlx/query-3c7960a2eb2edd71bc71177fc0fb2e83858c9944893b8f3a0f0131e8a9b7a494.json new file mode 100644 index 000000000..a45aacc7a --- /dev/null +++ b/crates/storage-pg/.sqlx/query-3c7960a2eb2edd71bc71177fc0fb2e83858c9944893b8f3a0f0131e8a9b7a494.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET status = 'available'\n WHERE\n status = 'scheduled'\n AND scheduled_at <= $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "3c7960a2eb2edd71bc71177fc0fb2e83858c9944893b8f3a0f0131e8a9b7a494" +} diff --git a/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json b/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json deleted file mode 100644 index 2962db553..000000000 --- 
a/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, attempt)\n SELECT $1, queue_name, payload, metadata, $2, attempt + 1\n FROM queue_jobs\n WHERE queue_job_id = $3\n AND status = 'failed'\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz", - "Uuid" - ] - }, - "nullable": [] - }, - "hash": "47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38" -} diff --git a/crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json b/crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json new file mode 100644 index 000000000..f87d2dff4 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, status)\n VALUES ($1, $2, $3, $4, $5, $6, 'scheduled')\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Jsonb", + "Jsonb", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943" +} diff --git a/crates/storage-pg/migrations/20241122130349_queue_job_scheduled.sql b/crates/storage-pg/migrations/20241122130349_queue_job_scheduled.sql new file mode 100644 index 000000000..e7aff6a04 --- /dev/null +++ b/crates/storage-pg/migrations/20241122130349_queue_job_scheduled.sql @@ -0,0 +1,11 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a new status for scheduled jobs +ALTER TYPE "queue_job_status" ADD VALUE 'scheduled'; + +ALTER TABLE "queue_jobs" + -- When the job is scheduled to run + ADD COLUMN "scheduled_at" TIMESTAMP WITH TIME ZONE; diff --git a/crates/storage-pg/migrations/20241122133435_queue_job_scheduled_index.sql b/crates/storage-pg/migrations/20241122133435_queue_job_scheduled_index.sql new file mode 100644 index 000000000..f8a7422e2 --- /dev/null +++ b/crates/storage-pg/migrations/20241122133435_queue_job_scheduled_index.sql @@ -0,0 +1,9 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a partial index on scheduled jobs +CREATE INDEX "queue_jobs_scheduled_at_idx" + ON "queue_jobs" ("scheduled_at") + WHERE "status" = 'scheduled'; diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index 769a9eb49..e6be988c9 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -7,6 +7,7 @@ //! [`QueueJobRepository`]. 
use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; use mas_storage::{ queue::{Job, QueueJobRepository, Worker}, Clock, @@ -117,6 +118,50 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { Ok(()) } + #[tracing::instrument( + name = "db.queue_job.schedule_later", + fields( + queue_job.id, + queue_job.queue_name = queue_name, + queue_job.scheduled_at = %scheduled_at, + db.query.text, + ), + skip_all, + err, + )] + async fn schedule_later( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + queue_name: &str, + payload: serde_json::Value, + metadata: serde_json::Value, + scheduled_at: DateTime, + ) -> Result<(), Self::Error> { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("queue_job.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO queue_jobs + (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, status) + VALUES ($1, $2, $3, $4, $5, $6, 'scheduled') + "#, + Uuid::from(id), + queue_name, + payload, + metadata, + created_at, + scheduled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + #[tracing::instrument( name = "db.queue_job.reserve", skip_all, @@ -264,8 +309,10 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { rng: &mut (dyn RngCore + Send), clock: &dyn Clock, id: Ulid, + delay: Duration, ) -> Result<(), Self::Error> { let now = clock.now(); + let scheduled_at = now + delay; let new_id = Ulid::from_datetime_with_source(now.into(), rng); // Create a new job with the same payload and metadata, but a new ID and @@ -274,14 +321,15 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { let res = sqlx::query!( r#" INSERT INTO queue_jobs - (queue_job_id, queue_name, payload, metadata, created_at, attempt) - SELECT $1, queue_name, payload, metadata, $2, attempt + 1 + (queue_job_id, queue_name, payload, metadata, created_at, attempt, scheduled_at, status) + SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, 'scheduled' FROM queue_jobs - WHERE queue_job_id = $3 + WHERE queue_job_id = $4 AND status = 'failed' "#, Uuid::from(new_id), now, + scheduled_at, Uuid::from(id), ) .traced() @@ -308,4 +356,32 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { Ok(()) } + + #[tracing::instrument( + name = "db.queue_job.schedule_available_jobs", + skip_all, + fields( + db.query.text, + ), + err + )] + async fn schedule_available_jobs(&mut self, clock: &dyn Clock) -> Result { + let now = clock.now(); + let res = sqlx::query!( + r#" + UPDATE queue_jobs + SET status = 'available' + WHERE + status = 'scheduled' + AND scheduled_at <= $1 + "#, + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let count = res.rows_affected(); + Ok(usize::try_from(count).unwrap_or(usize::MAX)) + } } diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs index 9a24fa649..5bbc5f75b 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -6,6 +6,7 @@ //! Repository to interact with jobs in the job queue use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; use opentelemetry::trace::TraceContextExt; use rand_core::RngCore; use serde::{Deserialize, Serialize}; @@ -105,6 +106,30 @@ pub trait QueueJobRepository: Send + Sync { metadata: serde_json::Value, ) -> Result<(), Self::Error>; + /// Schedule a job to be executed at a later date by a worker. 
+ /// + /// # Parameters + /// + /// * `rng` - The random number generator used to generate a new job ID + /// * `clock` - The clock used to generate timestamps + /// * `queue_name` - The name of the queue to schedule the job on + /// * `payload` - The payload of the job + /// * `metadata` - Arbitrary metadata about the job scheduled immediately. + /// * `scheduled_at` - The date and time to schedule the job for + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn schedule_later( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + queue_name: &str, + payload: serde_json::Value, + metadata: serde_json::Value, + scheduled_at: DateTime, + ) -> Result<(), Self::Error>; + /// Reserve multiple jobs from multiple queues /// /// # Parameters @@ -171,7 +196,18 @@ pub trait QueueJobRepository: Send + Sync { rng: &mut (dyn RngCore + Send), clock: &dyn Clock, id: Ulid, + delay: Duration, ) -> Result<(), Self::Error>; + + /// Mark all scheduled jobs past their scheduled date as available to be + /// executed. + /// + /// Returns the number of jobs that were marked as available. + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn schedule_available_jobs(&mut self, clock: &dyn Clock) -> Result; } repository_impl!(QueueJobRepository: @@ -184,6 +220,16 @@ repository_impl!(QueueJobRepository: metadata: serde_json::Value, ) -> Result<(), Self::Error>; + async fn schedule_later( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + queue_name: &str, + payload: serde_json::Value, + metadata: serde_json::Value, + scheduled_at: DateTime, + ) -> Result<(), Self::Error>; + async fn reserve( &mut self, clock: &dyn Clock, @@ -205,7 +251,10 @@ repository_impl!(QueueJobRepository: rng: &mut (dyn RngCore + Send), clock: &dyn Clock, id: Ulid, + delay: Duration, ) -> Result<(), Self::Error>; + + async fn schedule_available_jobs(&mut self, clock: &dyn Clock) -> Result; ); /// Extension trait for [`QueueJobRepository`] to help adding a job to the queue @@ -230,6 +279,26 @@ pub trait QueueJobRepositoryExt: QueueJobRepository { clock: &dyn Clock, job: J, ) -> Result<(), Self::Error>; + + /// Schedule a job to be executed at a later date by a worker. + /// + /// # Parameters + /// + /// * `rng` - The random number generator used to generate a new job ID + /// * `clock` - The clock used to generate timestamps + /// * `job` - The job to schedule + /// * `scheduled_at` - The date and time to schedule the job for + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. 
+ async fn schedule_job_later<J: InsertableJob>( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + job: J, + scheduled_at: DateTime<Utc>, + ) -> Result<(), Self::Error>; } #[async_trait] @@ -263,4 +332,32 @@ where self.schedule(rng, clock, J::QUEUE_NAME, payload, metadata) .await } + + #[tracing::instrument( + name = "db.queue_job.schedule_job_later", + fields( + queue_job.queue_name = J::QUEUE_NAME, + ), + skip_all, + )] + async fn schedule_job_later<J: InsertableJob>( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + job: J, + scheduled_at: DateTime<Utc>, + ) -> Result<(), Self::Error> { + // Grab the span context from the current span + let span = tracing::Span::current(); + let ctx = span.context(); + let span = ctx.span(); + let span_context = span.span_context(); + + let metadata = JobMetadata::new(span_context); + let metadata = serde_json::to_value(metadata).expect("Could not serialize metadata"); + + let payload = serde_json::to_value(job).expect("Could not serialize job"); + self.schedule_later(rng, clock, J::QUEUE_NAME, payload, metadata, scheduled_at) + .await + } } diff --git a/crates/tasks/src/new_queue.rs index ce8504116..fb2fa7151 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -160,6 +160,14 @@ const MAX_JOBS_TO_FETCH: usize = 5; // How many times a job will be attempted before it is marked as failed const MAX_ATTEMPTS: usize = 5; +/// Returns the delay to wait before retrying a job +/// +/// Uses an exponential backoff: 1s, 2s, 4s, 8s, 16s +fn retry_delay(attempt: usize) -> Duration { + let attempt = u32::try_from(attempt).unwrap_or(u32::MAX); + Duration::milliseconds(2_i64.saturating_pow(attempt) * 1000) +} + type JobResult = Result<(), JobError>; type JobFactory = Arc<dyn Fn(Job) -> Box<dyn RunnableJob + Send> + Send + Sync>;
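To make the backoff above concrete: with a 0-based `attempt`, `retry_delay` yields 1s, 2s, 4s, 8s and 16s across the five allowed attempts. A hypothetical in-module unit test (not part of the patch) would pin that down:

    #[test]
    fn retry_delay_is_exponential() {
        // 2^attempt seconds, expressed in milliseconds
        assert_eq!(retry_delay(0), Duration::milliseconds(1_000));
        assert_eq!(retry_delay(1), Duration::milliseconds(2_000));
        assert_eq!(retry_delay(4), Duration::milliseconds(16_000));
    }

@@ -516,6 +524,17 @@ impl QueueWorker { // TODO: mark tasks those workers had as lost + // Mark all the scheduled jobs as available + let scheduled = repo + .queue_job() + .schedule_available_jobs(&self.clock) + .await?; + match scheduled { + 0 => {} + 1 => tracing::warn!("One scheduled job marked as available"), + n => tracing::warn!("{n} scheduled jobs marked as available"), + } + + // Release the leader lock let txn = repo .into_inner() @@ -669,17 +688,19 @@ impl JobTracker { JobErrorDecision::Retry => { if context.attempt < MAX_ATTEMPTS { + let delay = retry_delay(context.attempt); tracing::warn!( error = &e as &dyn std::error::Error, job.id = %context.id, job.queue_name = %context.queue_name, job.attempt = %context.attempt, - "Job failed, will retry" + "Job failed, will retry in {}s", + delay.num_seconds() ); - // TODO: retry with an exponential backoff, once we know how to - // schedule jobs in the future - repo.queue_job().retry(&mut *rng, clock, context.id).await?; + repo.queue_job() + .retry(&mut *rng, clock, context.id, delay) + .await?; } else { tracing::error!( error = &e as &dyn std::error::Error, @@ -707,15 +728,19 @@ impl JobTracker { .await?; if context.attempt < MAX_ATTEMPTS { + let delay = retry_delay(context.attempt); tracing::warn!( error = &e as &dyn std::error::Error, job.id = %context.id, job.queue_name = %context.queue_name, job.attempt = %context.attempt, - "Job crashed, will retry" + "Job crashed, will retry in {}s", + delay.num_seconds() ); - repo.queue_job().retry(&mut *rng, clock, context.id).await?; + repo.queue_job() + .retry(&mut *rng, clock, context.id, delay) + .await?; } else { tracing::error!( error = &e as &dyn std::error::Error, From 8dbc9140d0bb3695cec6e85659b6b2af6b6555dc Mon Sep 17 00:00:00 2001 From: Quentin Gliech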
Date: Mon, 25 Nov 2024 17:40:08 +0100 Subject: [PATCH 16/17] Cron-like recurring jobs --- Cargo.lock | 13 ++ Cargo.toml | 4 + .../20241125110803_queue_job_recurrent.sql | 29 +++++ crates/storage-pg/src/queue/job.rs | 74 +++++++++++- crates/storage-pg/src/queue/mod.rs | 1 + crates/storage-pg/src/queue/schedule.rs | 91 ++++++++++++++ crates/storage-pg/src/repository.rs | 20 ++-- crates/storage/Cargo.toml | 1 + crates/storage/src/queue/job.rs | 18 ++- crates/storage/src/queue/mod.rs | 2 + crates/storage/src/queue/schedule.rs | 59 +++++++++ crates/storage/src/queue/tasks.rs | 9 ++ crates/storage/src/repository.rs | 21 +++- crates/tasks/Cargo.toml | 1 + crates/tasks/src/database.rs | 93 ++++---------- crates/tasks/src/lib.rs | 26 ++-- crates/tasks/src/new_queue.rs | 113 +++++++++++++++++- 17 files changed, 479 insertions(+), 96 deletions(-) create mode 100644 crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql create mode 100644 crates/storage-pg/src/queue/schedule.rs create mode 100644 crates/storage/src/queue/schedule.rs diff --git a/Cargo.lock b/Cargo.lock index 71b98659a..a05b843f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1314,6 +1314,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "cron" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eee8b2b4516038bc0f1d3c9934bcb4a13dd316e04abbc63c96757a6d75978532" +dependencies = [ + "chrono", + "nom", + "once_cell", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -3625,6 +3636,7 @@ version = "0.12.0" dependencies = [ "async-trait", "chrono", + "cron", "futures-util", "mas-data-model", "mas-iana", @@ -3676,6 +3688,7 @@ dependencies = [ "async-stream", "async-trait", "chrono", + "cron", "event-listener 5.3.1", "futures-lite", "mas-data-model", diff --git a/Cargo.toml b/Cargo.toml index 8d14cafde..a7309ed03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,6 +102,10 @@ features = ["serde", "clock"] version = "4.5.21" features = ["derive"] +# Cron expressions +[workspace.dependencies.cron] +version = "0.13.0" + # Elliptic curve cryptography [workspace.dependencies.elliptic-curve] version = "0.13.8" diff --git a/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql b/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql new file mode 100644 index 000000000..18c28803c --- /dev/null +++ b/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql @@ -0,0 +1,29 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a table to track the state of scheduled recurring jobs. +CREATE TABLE queue_schedules ( + -- A unique name for the schedule + schedule_name TEXT PRIMARY KEY, + + -- The cron expression to use to schedule the job. This is there just for + -- convenience, as this is defined by the backend + schedule_expression TEXT NOT NULL, + + -- The last time the job was scheduled. If NULL, it means that the job was + -- never scheduled. + last_scheduled_at TIMESTAMP WITH TIME ZONE, + + -- The job that was scheduled last time. 
If NULL, it means that either the + -- job was never scheduled, or the job was cleaned up from the database + last_scheduled_job_id UUID + REFERENCES queue_jobs (queue_job_id) +); + +-- When a job is scheduled from a recurreing schedule, we keep a column +-- referencing the name of the schedule +ALTER TABLE queue_jobs + ADD COLUMN schedule_name TEXT + REFERENCES queue_schedules (schedule_name); diff --git a/crates/storage-pg/src/queue/job.rs index e6be988c9..be7ed6bb2 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -12,8 +12,10 @@ use mas_storage::{ queue::{Job, QueueJobRepository, Worker}, Clock, }; +use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT; use rand::RngCore; use sqlx::PgConnection; +use tracing::Instrument; use ulid::Ulid; use uuid::Uuid; @@ -137,6 +139,7 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { payload: serde_json::Value, metadata: serde_json::Value, scheduled_at: DateTime<Utc>, + schedule_name: Option<&str>, ) -> Result<(), Self::Error> { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), rng); @@ -145,8 +148,8 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { sqlx::query!( r#" INSERT INTO queue_jobs - (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, status) - VALUES ($1, $2, $3, $4, $5, $6, 'scheduled') + (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, schedule_name, status) + VALUES ($1, $2, $3, $4, $5, $6, $7, 'scheduled') "#, Uuid::from(id), queue_name, @@ -154,11 +157,38 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { metadata, created_at, scheduled_at, + schedule_name, ) .traced() .execute(&mut *self.conn) .await?; + // If there was a schedule name supplied, update the queue_schedules table + if let Some(schedule_name) = schedule_name { + let span = tracing::info_span!( + "db.queue_job.schedule_later.update_schedules", + { DB_QUERY_TEXT } = tracing::field::Empty, + ); + + let res = sqlx::query!( + r#" + UPDATE queue_schedules + SET last_scheduled_at = $1, + last_scheduled_job_id = $2 + WHERE schedule_name = $3 + "#, + scheduled_at, + Uuid::from(id), + schedule_name, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + } + Ok(()) } @@ -315,14 +345,19 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { let scheduled_at = now + delay; let new_id = Ulid::from_datetime_with_source(now.into(), rng); + let span = tracing::info_span!( + "db.queue_job.retry.insert_job", + { DB_QUERY_TEXT } = tracing::field::Empty + ); // Create a new job with the same payload and metadata, but a new ID and // increment the attempt // We make sure we do this only for 'failed' jobs let res = sqlx::query!( r#" INSERT INTO queue_jobs - (queue_job_id, queue_name, payload, metadata, created_at, attempt, scheduled_at, status) - SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, 'scheduled' + (queue_job_id, queue_name, payload, metadata, created_at, + attempt, scheduled_at, schedule_name, status) + SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, schedule_name, 'scheduled' FROM queue_jobs WHERE queue_job_id = $4 AND status = 'failed' "#, Uuid::from(new_id), now, scheduled_at, Uuid::from(id), ) - .traced() + .record(&span) .execute(&mut *self.conn) + .instrument(span) .await?; DatabaseError::ensure_affected_rows(&res, 1)?; + // If that job was
referenced by a schedule, update the schedule + let span = tracing::info_span!( + "db.queue_job.retry.update_schedule", + { DB_QUERY_TEXT } = tracing::field::Empty + ); + sqlx::query!( + r#" + UPDATE queue_schedules + SET last_scheduled_at = $1, + last_scheduled_job_id = $2 + WHERE last_scheduled_job_id = $3 + "#, + scheduled_at, + Uuid::from(new_id), + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + // Update the old job to point to the new attempt + let span = tracing::info_span!( + "db.queue_job.retry.update_old_job", + { DB_QUERY_TEXT } = tracing::field::Empty + ); let res = sqlx::query!( r#" UPDATE queue_jobs @@ -348,8 +409,9 @@ impl QueueJobRepository for PgQueueJobRepository<'_> { Uuid::from(new_id), Uuid::from(id), ) - .traced() + .record(&span) .execute(&mut *self.conn) + .instrument(span) .await?; DatabaseError::ensure_affected_rows(&res, 1)?; diff --git a/crates/storage-pg/src/queue/mod.rs b/crates/storage-pg/src/queue/mod.rs index eca02b809..1c00e1d7d 100644 --- a/crates/storage-pg/src/queue/mod.rs +++ b/crates/storage-pg/src/queue/mod.rs @@ -6,4 +6,5 @@ //! A module containing the PostgreSQL implementation of the job queue pub mod job; +pub mod schedule; pub mod worker; diff --git a/crates/storage-pg/src/queue/schedule.rs b/crates/storage-pg/src/queue/schedule.rs new file mode 100644 index 000000000..41f4cb7cf --- /dev/null +++ b/crates/storage-pg/src/queue/schedule.rs @@ -0,0 +1,91 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! A module containing the PostgreSQL implementation of the +//! [`QueueScheduleRepository`]. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_storage::queue::{QueueScheduleRepository, Schedule, ScheduleStatus}; +use sqlx::PgConnection; + +use crate::{DatabaseError, ExecuteExt}; + +/// An implementation of [`QueueScheduleRepository`] for a PostgreSQL +/// connection. +pub struct PgQueueScheduleRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgQueueScheduleRepository<'c> { + /// Create a new [`PgQueueScheduleRepository`] from an active PostgreSQL + /// connection. 
+ #[must_use] + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct ScheduleLookup { + schedule_name: String, + last_scheduled_at: Option<DateTime<Utc>>, + last_scheduled_job_completed: Option<bool>, +} + +impl From<ScheduleLookup> for ScheduleStatus { + fn from(value: ScheduleLookup) -> Self { + ScheduleStatus { + schedule_name: value.schedule_name, + last_scheduled_at: value.last_scheduled_at, + last_scheduled_job_completed: value.last_scheduled_job_completed, + } + } +} + +#[async_trait] +impl<'c> QueueScheduleRepository for PgQueueScheduleRepository<'c> { + type Error = DatabaseError; + + async fn setup(&mut self, schedules: &[(&'static str, Schedule)]) -> Result<(), Self::Error> { + sqlx::query!( + r#" + INSERT INTO queue_schedules (schedule_name, schedule_expression) + SELECT * FROM UNNEST($1::text[], $2::text[]) AS t (schedule_name, schedule_expression) + ON CONFLICT (schedule_name) DO UPDATE + SET schedule_expression = EXCLUDED.schedule_expression + "#, + &schedules.iter().map(|(name, _)| (*name).to_owned()).collect::<Vec<_>>(), + &schedules + .iter() + .map(|(_, schedule)| schedule.source().to_owned()) + .collect::<Vec<_>>() + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + async fn list(&mut self) -> Result<Vec<ScheduleStatus>, Self::Error> { + let res = sqlx::query_as!( + ScheduleLookup, + r#" + SELECT + queue_schedules.schedule_name as "schedule_name!", + queue_schedules.last_scheduled_at, + queue_jobs.status IN ('completed', 'failed') as last_scheduled_job_completed + FROM queue_schedules + LEFT JOIN queue_jobs + ON queue_jobs.queue_job_id = queue_schedules.last_scheduled_job_id + "# + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + Ok(res.into_iter().map(Into::into).collect()) + } +} diff --git a/crates/storage-pg/src/repository.rs index b5c2b68b2..923221742 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -17,6 +17,7 @@ use mas_storage::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository}, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, @@ -38,7 +39,10 @@ use crate::{ PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, - queue::{job::PgQueueJobRepository, worker::PgQueueWorkerRepository}, + queue::{ + job::PgQueueJobRepository, schedule::PgQueueScheduleRepository, + worker::PgQueueWorkerRepository, + }, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -259,15 +263,17 @@ where Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut())) } - fn queue_worker<'c>( - &'c mut self, - ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> { + fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> { Box::new(PgQueueWorkerRepository::new(self.conn.as_mut())) } - fn queue_job<'c>( - &'c mut self, - ) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> { + fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> { Box::new(PgQueueJobRepository::new(self.conn.as_mut())) } + + fn queue_schedule<'c>( + &'c mut self, + ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> { + Box::new(PgQueueScheduleRepository::new(self.conn.as_mut())) + } }
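As a usage sketch for the schedule repository above (illustrative only; `repo` is an assumed `RepositoryAccess` handle, and this is the `(name, Schedule)` shape that the last patch in this series later simplifies):

    // Seed the known schedules, then inspect their recorded state.
    let every_15s: Schedule = "*/15 * * * * *".parse().expect("valid cron expression");
    repo.queue_schedule().setup(&[("cleanup-expired-tokens", every_15s)]).await?;
    for status in repo.queue_schedule().list().await? {
        println!("{}: last scheduled at {:?}", status.schedule_name, status.last_scheduled_at);
    }

diff --git a/crates/storage/Cargo.toml index 22d209df0..97e06b507 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -14,6 +14,7 @@ workspace =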
true [dependencies] async-trait.workspace = true chrono.workspace = true +cron.workspace = true futures-util.workspace = true opentelemetry.workspace = true rand_core = "0.6.4" diff --git a/crates/storage/src/queue/job.rs index 5bbc5f75b..e4c9f7235 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -7,6 +7,7 @@ use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; +use cron::Schedule; use opentelemetry::trace::TraceContextExt; use rand_core::RngCore; use serde::{Deserialize, Serialize}; @@ -116,10 +117,13 @@ pub trait QueueJobRepository: Send + Sync { /// * `payload` - The payload of the job /// * `metadata` - Arbitrary metadata about the job /// * `scheduled_at` - The date and time to schedule the job for + /// * `schedule_name` - The name of the recurring schedule which scheduled + /// this job /// /// # Errors /// /// Returns an error if the underlying repository fails. + #[allow(clippy::too_many_arguments)] async fn schedule_later( &mut self, rng: &mut (dyn RngCore + Send), @@ -128,6 +132,7 @@ payload: serde_json::Value, metadata: serde_json::Value, scheduled_at: DateTime<Utc>, + schedule_name: Option<&str>, ) -> Result<(), Self::Error>; /// Reserve multiple jobs from multiple queues @@ -228,6 +233,7 @@ repository_impl!(QueueJobRepository: payload: serde_json::Value, metadata: serde_json::Value, scheduled_at: DateTime<Utc>, + schedule_name: Option<&str>, ) -> Result<(), Self::Error>; async fn reserve( @@ -357,7 +363,15 @@ where let metadata = serde_json::to_value(metadata).expect("Could not serialize metadata"); let payload = serde_json::to_value(job).expect("Could not serialize job"); - self.schedule_later(rng, clock, J::QUEUE_NAME, payload, metadata, scheduled_at) - .await + self.schedule_later( + rng, + clock, + J::QUEUE_NAME, + payload, + metadata, + scheduled_at, + None, + ) + .await } } diff --git a/crates/storage/src/queue/mod.rs index d02bee5fd..a41bd4438 100644 --- a/crates/storage/src/queue/mod.rs +++ b/crates/storage/src/queue/mod.rs @@ -6,11 +6,13 @@ //! A module containing repositories for the job queue mod job; +mod schedule; mod tasks; mod worker; pub use self::{ job::{InsertableJob, Job, JobMetadata, QueueJobRepository, QueueJobRepositoryExt}, + schedule::{QueueScheduleRepository, Schedule, ScheduleStatus}, tasks::*, worker::{QueueWorkerRepository, Worker}, }; diff --git a/crates/storage/src/queue/schedule.rs new file mode 100644 index 000000000..aaee5d325 --- /dev/null +++ b/crates/storage/src/queue/schedule.rs @@ -0,0 +1,59 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! Repository to interact with recurrent scheduled jobs in the job queue + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +pub use cron::Schedule; + +use crate::repository_impl; + +/// [`QueueScheduleRepository::list`] returns a list of [`ScheduleStatus`], +/// which has the name of the schedule and information about its last run +pub struct ScheduleStatus { + /// Name of the schedule, uniquely identifying it + pub schedule_name: String, + /// When the schedule was last run + pub last_scheduled_at: Option<DateTime<Utc>>, + /// Did the last job on this schedule finish?
(successfully or not) + pub last_scheduled_job_completed: Option<bool>, +} + +/// A [`QueueScheduleRepository`] is used to interact with recurrent scheduled +/// jobs in the job queue. +#[async_trait] +pub trait QueueScheduleRepository: Send + Sync { + /// The error type returned by the repository. + type Error; + + /// Set up the list of schedules in the repository + /// + /// # Parameters + /// + /// * `schedules` - The list of schedules to setup, as a list of (name, + /// schedule) + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn setup(&mut self, schedules: &[(&'static str, Schedule)]) -> Result<(), Self::Error>; + + /// List the schedules in the repository, with the last time they were run + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn list(&mut self) -> Result<Vec<ScheduleStatus>, Self::Error>; +} + +repository_impl!(QueueScheduleRepository: + async fn setup( + &mut self, + schedules: &[(&'static str, Schedule)], + ) -> Result<(), Self::Error>; + + async fn list(&mut self) -> Result<Vec<ScheduleStatus>, Self::Error>; +);
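The nullable boolean produced by the `LEFT JOIN` in the Postgres implementation is deliberately three-valued; an illustrative consumer (not in the patch, `status` is an assumed `ScheduleStatus` binding) would read it like this:

    match status.last_scheduled_job_completed {
        // Never scheduled, or the last job row was cleaned up from queue_jobs
        None => { /* safe to schedule a run */ }
        // The last scheduled job is still pending or running
        Some(false) => { /* hold off */ }
        // The last job reached a terminal state ('completed' or 'failed')
        Some(true) => { /* schedule the next run */ }
    }

diff --git a/crates/storage/src/queue/tasks.rs index a2fe85be4..a193f2037 100644 --- a/crates/storage/src/queue/tasks.rs +++ b/crates/storage/src/queue/tasks.rs @@ -3,6 +3,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use chrono::DateTime; use mas_data_model::{Device, User, UserEmail, UserRecoverySession}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -288,3 +289,11 @@ impl SendAccountRecoveryEmailsJob { impl InsertableJob for SendAccountRecoveryEmailsJob { const QUEUE_NAME: &'static str = "send-account-recovery-email"; } + +/// Cleanup expired tokens +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +pub struct CleanupExpiredTokensJob; + +impl InsertableJob for CleanupExpiredTokensJob { + const QUEUE_NAME: &'static str = "cleanup-expired-tokens"; +} diff --git a/crates/storage/src/repository.rs index 161ef05e3..ab70a287a 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -17,7 +17,7 @@ use crate::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, - queue::{QueueJobRepository, QueueWorkerRepository}, + queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository}, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, @@ -194,6 +194,11 @@ pub trait RepositoryAccess: Send { /// Get a [`QueueJobRepository`] fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c>; + + /// Get a [`QueueScheduleRepository`] + fn queue_schedule<'c>( + &'c mut self, + ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c>; } /// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and @@ -213,7 +218,7 @@ mod impls { OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, - queue::{QueueJobRepository, QueueWorkerRepository}, + queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository}, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, @@ -414,6 +419,12 @@ mod impls { fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> { Box::new(MapErr::new(self.inner.queue_job(), &mut self.mapper)) } + + fn queue_schedule<'c>( + &'c mut self, + ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {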
Box::new(MapErr::new(self.inner.queue_schedule(), &mut self.mapper)) + } } impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> { @@ -542,5 +553,11 @@ fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> { (**self).queue_job() } + + fn queue_schedule<'c>( + &'c mut self, + ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> { + (**self).queue_schedule() + } } } diff --git a/crates/tasks/Cargo.toml index 7a18ca0aa..de62f4998 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -15,6 +15,7 @@ workspace = true [dependencies] anyhow.workspace = true async-stream = "0.3.6" async-trait.workspace = true +cron.workspace = true chrono.workspace = true event-listener = "5.3.1" futures-lite = "2.5.0" diff --git a/crates/tasks/src/database.rs index 5b80abfd8..b64e1defb 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -6,78 +6,35 @@ //! Database-related tasks -use std::str::FromStr; - -use apalis_core::{ - builder::{WorkerBuilder, WorkerFactoryFn}, - context::JobContext, - executor::TokioExecutor, - job::Job, - monitor::Monitor, - utils::timer::TokioTimer, -}; -use apalis_cron::CronStream; -use chrono::{DateTime, Utc}; -use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess}; +use async_trait::async_trait; +use mas_storage::queue::CleanupExpiredTokensJob; use tracing::{debug, info}; use crate::{ - utils::{metrics_layer, trace_layer, TracedJob}, - JobContextExt, State, + new_queue::{JobContext, JobError, RunnableJob}, + State, }; -#[derive(Default, Clone)] -pub struct CleanupExpiredTokensJob { - scheduled: DateTime<Utc>, -} - -impl From<DateTime<Utc>> for CleanupExpiredTokensJob { - fn from(scheduled: DateTime<Utc>) -> Self { - Self { scheduled } +#[async_trait] +impl RunnableJob for CleanupExpiredTokensJob { + #[tracing::instrument(name = "job.cleanup_expired_tokens", skip_all, err)] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { + let clock = state.clock(); + let mut repo = state.repository().await.map_err(JobError::retry)?; + + let count = repo + .oauth2_access_token() + .cleanup_expired(&clock) + .await + .map_err(JobError::retry)?; + repo.save().await.map_err(JobError::retry)?; + + if count == 0 { + debug!("no token to clean up"); + } else { + info!(count, "cleaned up expired tokens"); + } + + Ok(()) } } - -impl Job for CleanupExpiredTokensJob { - const NAME: &'static str = "cleanup-expired-tokens"; -} - -impl TracedJob for CleanupExpiredTokensJob {} - -pub async fn cleanup_expired_tokens( - job: CleanupExpiredTokensJob, - ctx: JobContext, -) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { - debug!("cleanup expired tokens job scheduled at {}", job.scheduled); - - let state = ctx.state(); - let clock = state.clock(); - let mut repo = state.repository().await?; - - let count = repo.oauth2_access_token().cleanup_expired(&clock).await?; - repo.save().await?; - - if count == 0 { - debug!("no token to clean up"); - } else { - info!(count, "cleaned up expired tokens"); - } - - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor<TokioExecutor>, - state: &State, -) -> Monitor<TokioExecutor> { - let schedule = apalis_cron::Schedule::from_str("*/15 * * * * *").unwrap(); - let worker_name = format!("{job}-{suffix}", job = CleanupExpiredTokensJob::NAME); - let worker = WorkerBuilder::new(worker_name) - .stream(CronStream::new(schedule).timer(TokioTimer).to_stream()) - .layer(state.inject()) - .layer(metrics_layer()) - .layer(trace_layer()) - .build_fn(cleanup_expired_tokens); - - monitor.register(worker) -}
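This port shows the pattern every task now follows: implement `RunnableJob` for the payload type and map fallible steps through `JobError::retry`. A minimal sketch for some future task (`MyJob` is hypothetical; everything else mirrors the impl above):

    #[async_trait]
    impl RunnableJob for MyJob {
        #[tracing::instrument(name = "job.my_job", skip_all, err)]
        async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
            let mut repo = state.repository().await.map_err(JobError::retry)?;
            // ... domain work against `repo` goes here ...
            repo.save().await.map_err(JobError::retry)?;
            Ok(())
        }
    }

diff --git a/crates/tasks/src/lib.rs index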
ad2ede868..ecfd1ca18 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -18,8 +18,7 @@ use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; -// TODO: we need to have a way to schedule recurring tasks -// mod database; +mod database; mod email; mod matrix; mod new_queue; @@ -110,14 +109,21 @@ pub async fn init( ); let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?; - worker.register_handler::(); - worker.register_handler::(); - worker.register_handler::(); - worker.register_handler::(); - worker.register_handler::(); - worker.register_handler::(); - worker.register_handler::(); - worker.register_handler::(); + worker + .register_handler::() + .register_handler::() + .register_handler::() + .register_handler::() + .register_handler::() + .register_handler::() + .register_handler::() + .register_handler::() + .register_handler::() + .add_schedule( + "cleanup-expired-tokens", + "*/15 * * * * *".parse()?, + mas_storage::queue::CleanupExpiredTokensJob, + ); task_tracker.spawn(async move { if let Err(e) = worker.run().await { diff --git a/crates/tasks/src/new_queue.rs index fb2fa7151..afba9c73e 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -8,7 +8,7 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; use mas_storage::{ - queue::{InsertableJob, Job, JobMetadata, Worker}, + queue::{InsertableJob, Job, JobMetadata, Schedule, Worker}, Clock, RepositoryAccess, RepositoryError, }; use mas_storage_pg::{DatabaseError, PgRepository}; @@ -140,6 +140,9 @@ pub enum QueueRunnerError { #[error(transparent)] Database(#[from] DatabaseError), + #[error("Invalid schedule expression")] + InvalidSchedule(#[from] cron::error::Error), + #[error("Worker is not the leader")] NotLeader, } @@ -171,6 +174,13 @@ fn retry_delay(attempt: usize) -> Duration { type JobResult = Result<(), JobError>; type JobFactory = Arc<dyn Fn(Job) -> Box<dyn RunnableJob + Send> + Send + Sync>; +struct ScheduleDefinition { + schedule_name: &'static str, + expression: Schedule, + queue_name: &'static str, + payload: serde_json::Value, +} + pub struct QueueWorker { rng: ChaChaRng, clock: Box<dyn Clock + Send>, @@ -180,6 +190,7 @@ pub struct QueueWorker { last_heartbeat: DateTime<Utc>, cancellation_token: CancellationToken, state: State, + schedules: Vec<ScheduleDefinition>, tracker: JobTracker, } @@ -237,6 +248,7 @@ impl QueueWorker { last_heartbeat: now, cancellation_token, state, + schedules: Vec::new(), tracker: JobTracker::default(), }) } @@ -254,7 +266,27 @@ impl QueueWorker { self } + pub fn add_schedule<T: InsertableJob>( + &mut self, + schedule_name: &'static str, + expression: Schedule, + job: T, + ) -> &mut Self { + let payload = serde_json::to_value(job).expect("failed to serialize job payload"); + + self.schedules.push(ScheduleDefinition { + schedule_name, + expression, + queue_name: T::QUEUE_NAME, + payload, + }); + + self + }
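`add_schedule` expects an already-parsed `cron::Schedule`; the crate's six-field expressions include a seconds column. A small standalone sketch of the parse-and-next-tick dance the worker performs (illustrative, using only chrono and the cron crate):

    use chrono::Utc;
    use cron::Schedule;

    // "every 15 seconds": sec min hour day-of-month month day-of-week
    let every_15s: Schedule = "*/15 * * * * *".parse().expect("valid cron expression");
    if let Some(next_tick) = every_15s.after(&Utc::now()).next() {
        println!("next run at {next_tick}");
    }

+ pub async fn run(&mut self) -> Result<(), QueueRunnerError> { + self.setup_schedules().await?; + while !self.cancellation_token.is_cancelled() { self.run_loop().await?; } @@ -264,6 +296,34 @@ Ok(()) } + #[tracing::instrument(name = "worker.setup_schedules", skip_all, err)] + pub async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> { + let schedules: Vec<_> = self + .schedules + .iter() + .map(|s| (s.schedule_name, s.expression.clone())) + .collect(); + + // Start a transaction on the existing PgListener connection + let txn = self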
.listener .begin() .await .map_err(QueueRunnerError::StartTransaction)?; + + let mut repo = PgRepository::from_conn(txn); + + // Set up the entries in the queue_schedules table + repo.queue_schedule().setup(&schedules).await?; + + repo.into_inner() + .commit() + .await + .map_err(QueueRunnerError::CommitTransaction)?; + + Ok(()) + } + #[tracing::instrument(name = "worker.run_loop", skip_all, err)] async fn run_loop(&mut self) -> Result<(), QueueRunnerError> { self.wait_until_wakeup().await?; @@ -516,6 +576,57 @@ impl QueueWorker { let mut repo = PgRepository::from_conn(locked); + // Look at the state of schedules in the database + let schedules_status = repo.queue_schedule().list().await?; + + let now = self.clock.now(); + for schedule in &self.schedules { + // Find the schedule status from the database + let Some(schedule_status) = schedules_status + .iter() + .find(|s| s.schedule_name == schedule.schedule_name) + else { + tracing::error!( + "Schedule {} was not found in the database", + schedule.schedule_name + ); + continue; + }; + + // Figure out if we should schedule a new job + if let Some(next_time) = schedule_status.last_scheduled_at { + if next_time > now { + // We already have a job scheduled in the future, skip + continue; + } + + if schedule_status.last_scheduled_job_completed == Some(false) { + // The last scheduled job has not completed yet, skip + continue; + } + } + + let next_tick = schedule.expression.after(&now).next().unwrap(); + + tracing::info!( + "Scheduling job for {}, next run at {}", + schedule.schedule_name, + next_tick + ); + + repo.queue_job() + .schedule_later( + &mut self.rng, + &self.clock, + schedule.queue_name, + schedule.payload.clone(), + serde_json::json!({}), + next_tick, + Some(schedule.schedule_name), + ) + .await?; + } + // We also check for dead workers, and shut down all the dead // workers that haven't checked in for the last two minutes repo.queue_worker() From 45b4fcce50465a07cface2ceec8806d17675e3a3 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 5 Dec 2024 17:55:34 +0100 Subject: [PATCH 17/17] Remove the schedule_expression from the database & other fixes --- ...dc8e678278121efbe1f66bcdc24144d684d0.json} | 7 ++-- ...ce891e626a82dcb78ff85f2b815d9329ff936.json | 17 ---------- ...9bb249b18ced57d6a4809dffc23972b3e9423.json | 16 ++++++++++ ...30554dc067d0a6cad963dd7e0c66a80b342bf.json | 16 ++++++++++ ...61540441b14c8206038fdc4a4336bbae3f382.json | 17 ++++++++++ ...738455e94eade48ad5f577e53278cc70dc266.json | 32 +++++++++++++++++++ ...bc9991135065e81af8f77b5beef9405607577.json | 14 ++++++++ .../20241125110803_queue_job_recurrent.sql | 8 ++--- crates/storage-pg/src/queue/schedule.rs | 21 +++++------- crates/storage/src/queue/job.rs | 1 - crates/storage/src/queue/mod.rs | 2 +- crates/storage/src/queue/schedule.rs | 8 ++--- crates/storage/src/queue/tasks.rs | 1 - crates/tasks/src/new_queue.rs | 9 ++---- 14 files changed, 116 insertions(+), 53 deletions(-) rename crates/storage-pg/.sqlx/{query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json => query-245cab1cf7d9cf4e94cdec91ecb4dc8e678278121efbe1f66bcdc24144d684d0.json} (60%) delete mode 100644 crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json create mode 100644 crates/storage-pg/.sqlx/query-3e6e3aad53b22fc53eb3ee881b29bb249b18ced57d6a4809dffc23972b3e9423.json create mode 100644 crates/storage-pg/.sqlx/query-5b21644dd3c094b0f2f8babb2c730554dc067d0a6cad963dd7e0c66a80b342bf.json create mode 100644
crates/storage-pg/.sqlx/query-8f4f071f844281fb14ecd99db3261540441b14c8206038fdc4a4336bbae3f382.json create mode 100644 crates/storage-pg/.sqlx/query-9ad4e6e9bfedea476d1f47753e4738455e94eade48ad5f577e53278cc70dc266.json create mode 100644 crates/storage-pg/.sqlx/query-f8182fd162ffb018d4f102fa7ddbc9991135065e81af8f77b5beef9405607577.json diff --git a/crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json b/crates/storage-pg/.sqlx/query-245cab1cf7d9cf4e94cdec91ecb4dc8e678278121efbe1f66bcdc24144d684d0.json similarity index 60% rename from crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json rename to crates/storage-pg/.sqlx/query-245cab1cf7d9cf4e94cdec91ecb4dc8e678278121efbe1f66bcdc24144d684d0.json index f87d2dff4..b6635baa8 100644 --- a/crates/storage-pg/.sqlx/query-d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943.json +++ b/crates/storage-pg/.sqlx/query-245cab1cf7d9cf4e94cdec91ecb4dc8e678278121efbe1f66bcdc24144d684d0.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, status)\n VALUES ($1, $2, $3, $4, $5, $6, 'scheduled')\n ", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, schedule_name, status)\n VALUES ($1, $2, $3, $4, $5, $6, $7, 'scheduled')\n ", "describe": { "columns": [], "parameters": { @@ -10,10 +10,11 @@ "Jsonb", "Jsonb", "Timestamptz", - "Timestamptz" + "Timestamptz", + "Text" ] }, "nullable": [] }, - "hash": "d6c4cc9b04086f1b6ffad30d8a859e9fc0bf8a1fe9002dc3854ae28e65fc7943" + "hash": "245cab1cf7d9cf4e94cdec91ecb4dc8e678278121efbe1f66bcdc24144d684d0" } diff --git a/crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json b/crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json deleted file mode 100644 index c65354f92..000000000 --- a/crates/storage-pg/.sqlx/query-3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, attempt, scheduled_at, status)\n SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, 'scheduled'\n FROM queue_jobs\n WHERE queue_job_id = $4\n AND status = 'failed'\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz", - "Timestamptz", - "Uuid" - ] - }, - "nullable": [] - }, - "hash": "3355b3b5729d8240297a5ac8111ce891e626a82dcb78ff85f2b815d9329ff936" -} diff --git a/crates/storage-pg/.sqlx/query-3e6e3aad53b22fc53eb3ee881b29bb249b18ced57d6a4809dffc23972b3e9423.json b/crates/storage-pg/.sqlx/query-3e6e3aad53b22fc53eb3ee881b29bb249b18ced57d6a4809dffc23972b3e9423.json new file mode 100644 index 000000000..a930b70e9 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-3e6e3aad53b22fc53eb3ee881b29bb249b18ced57d6a4809dffc23972b3e9423.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_schedules\n SET last_scheduled_at = $1,\n last_scheduled_job_id = $2\n WHERE schedule_name = $3\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid", + "Text" + ] + }, + "nullable": [] + }, + "hash": "3e6e3aad53b22fc53eb3ee881b29bb249b18ced57d6a4809dffc23972b3e9423" +} diff --git 
a/crates/storage-pg/.sqlx/query-5b21644dd3c094b0f2f8babb2c730554dc067d0a6cad963dd7e0c66a80b342bf.json b/crates/storage-pg/.sqlx/query-5b21644dd3c094b0f2f8babb2c730554dc067d0a6cad963dd7e0c66a80b342bf.json new file mode 100644 index 000000000..ea5a5fb0a --- /dev/null +++ b/crates/storage-pg/.sqlx/query-5b21644dd3c094b0f2f8babb2c730554dc067d0a6cad963dd7e0c66a80b342bf.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_schedules\n SET last_scheduled_at = $1,\n last_scheduled_job_id = $2\n WHERE last_scheduled_job_id = $3\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "5b21644dd3c094b0f2f8babb2c730554dc067d0a6cad963dd7e0c66a80b342bf" +} diff --git a/crates/storage-pg/.sqlx/query-8f4f071f844281fb14ecd99db3261540441b14c8206038fdc4a4336bbae3f382.json b/crates/storage-pg/.sqlx/query-8f4f071f844281fb14ecd99db3261540441b14c8206038fdc4a4336bbae3f382.json new file mode 100644 index 000000000..304e477e6 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-8f4f071f844281fb14ecd99db3261540441b14c8206038fdc4a4336bbae3f382.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at,\n attempt, scheduled_at, schedule_name, status)\n SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, schedule_name, 'scheduled'\n FROM queue_jobs\n WHERE queue_job_id = $4\n AND status = 'failed'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "8f4f071f844281fb14ecd99db3261540441b14c8206038fdc4a4336bbae3f382" +} diff --git a/crates/storage-pg/.sqlx/query-9ad4e6e9bfedea476d1f47753e4738455e94eade48ad5f577e53278cc70dc266.json b/crates/storage-pg/.sqlx/query-9ad4e6e9bfedea476d1f47753e4738455e94eade48ad5f577e53278cc70dc266.json new file mode 100644 index 000000000..6a0c3b950 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-9ad4e6e9bfedea476d1f47753e4738455e94eade48ad5f577e53278cc70dc266.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n queue_schedules.schedule_name,\n queue_schedules.last_scheduled_at,\n queue_jobs.status IN ('completed', 'failed') as last_scheduled_job_completed\n FROM queue_schedules\n LEFT JOIN queue_jobs\n ON queue_jobs.queue_job_id = queue_schedules.last_scheduled_job_id\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "schedule_name", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "last_scheduled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "last_scheduled_job_completed", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + true, + null + ] + }, + "hash": "9ad4e6e9bfedea476d1f47753e4738455e94eade48ad5f577e53278cc70dc266" +} diff --git a/crates/storage-pg/.sqlx/query-f8182fd162ffb018d4f102fa7ddbc9991135065e81af8f77b5beef9405607577.json b/crates/storage-pg/.sqlx/query-f8182fd162ffb018d4f102fa7ddbc9991135065e81af8f77b5beef9405607577.json new file mode 100644 index 000000000..1a715f579 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-f8182fd162ffb018d4f102fa7ddbc9991135065e81af8f77b5beef9405607577.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_schedules (schedule_name)\n SELECT * FROM UNNEST($1::text[]) AS t (schedule_name)\n ON CONFLICT (schedule_name) DO NOTHING\n ", + "describe": { + "columns": [], + "parameters": { + "Left": 
[ + "TextArray" + ] + }, + "nullable": [] + }, + "hash": "f8182fd162ffb018d4f102fa7ddbc9991135065e81af8f77b5beef9405607577" +} diff --git a/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql b/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql index 18c28803c..814e073c4 100644 --- a/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql +++ b/crates/storage-pg/migrations/20241125110803_queue_job_recurrent.sql @@ -6,11 +6,7 @@ -- Add a table to track the state of scheduled recurring jobs. CREATE TABLE queue_schedules ( -- A unique name for the schedule - schedule_name TEXT PRIMARY KEY, - - -- The cron expression to use to schedule the job. This is there just for - -- convenience, as this is defined by the backend - schedule_expression TEXT NOT NULL, + schedule_name TEXT NOT NULL PRIMARY KEY, -- The last time the job was scheduled. If NULL, it means that the job was -- never scheduled. @@ -22,7 +18,7 @@ CREATE TABLE queue_schedules ( REFERENCES queue_jobs (queue_job_id) ); --- When a job is scheduled from a recurreing schedule +-- When a job is scheduled from a recurring schedule -- referencing the name of the schedule ALTER TABLE queue_jobs ADD COLUMN schedule_name TEXT diff --git a/crates/storage-pg/src/queue/schedule.rs b/crates/storage-pg/src/queue/schedule.rs index 41f4cb7cf..3594cee7e 100644 --- a/crates/storage-pg/src/queue/schedule.rs +++ b/crates/storage-pg/src/queue/schedule.rs @@ -8,7 +8,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_storage::queue::{QueueScheduleRepository, Schedule, ScheduleStatus}; +use mas_storage::queue::{QueueScheduleRepository, ScheduleStatus}; use sqlx::PgConnection; use crate::{DatabaseError, ExecuteExt}; @@ -45,22 +45,17 @@ impl From<ScheduleLookup> for ScheduleStatus { } #[async_trait] -impl<'c> QueueScheduleRepository for PgQueueScheduleRepository<'c> { +impl QueueScheduleRepository for PgQueueScheduleRepository<'_> { type Error = DatabaseError; - async fn setup(&mut self, schedules: &[(&'static str, Schedule)]) -> Result<(), Self::Error> { + async fn setup(&mut self, schedules: &[&'static str]) -> Result<(), Self::Error> { sqlx::query!( r#" - INSERT INTO queue_schedules (schedule_name, schedule_expression) - SELECT * FROM UNNEST($1::text[], $2::text[]) AS t (schedule_name, schedule_expression) - ON CONFLICT (schedule_name) DO UPDATE - SET schedule_expression = EXCLUDED.schedule_expression + INSERT INTO queue_schedules (schedule_name) + SELECT * FROM UNNEST($1::text[]) AS t (schedule_name) + ON CONFLICT (schedule_name) DO NOTHING "#, - &schedules.iter().map(|(name, _)| (*name).to_owned()).collect::<Vec<_>>(), - &schedules - .iter() - .map(|(_, schedule)| schedule.source().to_owned()) - .collect::<Vec<_>>() + &schedules.iter().map(|&s| s.to_owned()).collect::<Vec<_>>(), ) .traced() .execute(&mut *self.conn) @@ -74,7 +69,7 @@ impl<'c> QueueScheduleRepository for PgQueueScheduleRepository<'c> { ScheduleLookup, r#" SELECT - queue_schedules.schedule_name as "schedule_name!", + queue_schedules.schedule_name, queue_schedules.last_scheduled_at, queue_jobs.status IN ('completed', 'failed') as last_scheduled_job_completed FROM queue_schedules diff --git a/crates/storage/src/queue/job.rs index e4c9f7235..298d2f758 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -7,7 +7,6 @@ use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; -use cron::Schedule; use opentelemetry::trace::TraceContextExt; use
rand_core::RngCore; use serde::{Deserialize, Serialize}; diff --git a/crates/storage/src/queue/mod.rs index a41bd4438..03d969bbb 100644 --- a/crates/storage/src/queue/mod.rs +++ b/crates/storage/src/queue/mod.rs @@ -12,7 +12,7 @@ mod worker; pub use self::{ job::{InsertableJob, Job, JobMetadata, QueueJobRepository, QueueJobRepositoryExt}, - schedule::{QueueScheduleRepository, Schedule, ScheduleStatus}, + schedule::{QueueScheduleRepository, ScheduleStatus}, tasks::*, worker::{QueueWorkerRepository, Worker}, }; diff --git a/crates/storage/src/queue/schedule.rs index aaee5d325..aaa83e5d2 100644 --- a/crates/storage/src/queue/schedule.rs +++ b/crates/storage/src/queue/schedule.rs @@ -7,7 +7,6 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -pub use cron::Schedule; use crate::repository_impl; @@ -33,13 +32,12 @@ pub trait QueueScheduleRepository: Send + Sync { /// /// # Parameters /// - /// * `schedules` - The list of schedules to setup, as a list of (name, - /// schedule) + /// * `schedules` - The list of schedules to setup /// /// # Errors /// /// Returns an error if the underlying repository fails. - async fn setup(&mut self, schedules: &[(&'static str, Schedule)]) -> Result<(), Self::Error>; + async fn setup(&mut self, schedules: &[&'static str]) -> Result<(), Self::Error>; /// List the schedules in the repository, with the last time they were run /// @@ -52,7 +50,7 @@ repository_impl!(QueueScheduleRepository: async fn setup( &mut self, - schedules: &[(&'static str, Schedule)], + schedules: &[&'static str], ) -> Result<(), Self::Error>; async fn list(&mut self) -> Result<Vec<ScheduleStatus>, Self::Error>; ); diff --git a/crates/storage/src/queue/tasks.rs index a193f2037..fe8b1f9e5 100644 --- a/crates/storage/src/queue/tasks.rs +++ b/crates/storage/src/queue/tasks.rs @@ -3,7 +3,6 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use chrono::DateTime; use mas_data_model::{Device, User, UserEmail, UserRecoverySession}; use serde::{Deserialize, Serialize}; use ulid::Ulid; diff --git a/crates/tasks/src/new_queue.rs index afba9c73e..3eab5e53c 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -7,8 +7,9 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; +use cron::Schedule; use mas_storage::{ - queue::{InsertableJob, Job, JobMetadata, Schedule, Worker}, + queue::{InsertableJob, Job, JobMetadata, Worker}, Clock, RepositoryAccess, RepositoryError, }; use mas_storage_pg::{DatabaseError, PgRepository}; @@ -298,11 +299,7 @@ impl QueueWorker { #[tracing::instrument(name = "worker.setup_schedules", skip_all, err)] pub async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> { - let schedules: Vec<_> = self - .schedules - .iter() - .map(|s| (s.schedule_name, s.expression.clone())) - .collect(); + let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect(); // Start a transaction on the existing PgListener connection let txn = self