2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion crates/cli/src/commands/server.rs
@@ -173,8 +173,9 @@ impl Options {
test_mailer_in_background(&mailer, Duration::from_secs(30));

info!("Starting task worker");
mas_tasks::init(
mas_tasks::init_and_run(
PgRepositoryFactory::new(pool.clone()),
SystemClock::default(),
&mailer,
homeserver_connection.clone(),
url_builder.clone(),
4 changes: 3 additions & 1 deletion crates/cli/src/commands/worker.rs
@@ -10,6 +10,7 @@ use clap::Parser;
use figment::Figment;
use mas_config::{AppConfig, ConfigurationSection};
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};

@@ -63,8 +64,9 @@ impl Options {
drop(config);

info!("Starting task scheduler");
mas_tasks::init(
mas_tasks::init_and_run(
PgRepositoryFactory::new(pool.clone()),
SystemClock::default(),
&mailer,
conn,
url_builder,
2 changes: 2 additions & 0 deletions crates/handlers/Cargo.toml
@@ -72,6 +72,7 @@ mas-axum-utils.workspace = true
mas-config.workspace = true
mas-context.workspace = true
mas-data-model.workspace = true
mas-email.workspace = true
mas-http.workspace = true
mas-i18n.workspace = true
mas-iana.workspace = true
@@ -83,6 +84,7 @@ mas-policy.workspace = true
mas-router.workspace = true
mas-storage.workspace = true
mas-storage-pg.workspace = true
mas-tasks.workspace = true
mas-templates.workspace = true
oauth2-types.workspace = true
zxcvbn.workspace = true
83 changes: 64 additions & 19 deletions crates/handlers/src/admin/v1/users/deactivate.rs
@@ -105,8 +105,9 @@ pub async fn handler(
mod tests {
use chrono::Duration;
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use mas_storage::{Clock, RepositoryAccess, user::UserRepository};
use sqlx::{PgPool, types::Json};
use sqlx::PgPool;

use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

@@ -137,15 +138,37 @@ mod tests {
serde_json::json!(state.clock.now())
);

// It should have scheduled a deactivation job for the user
// XXX: we don't have a good way to look for the deactivation job
let job: Json<serde_json::Value> = sqlx::query_scalar(
"SELECT payload FROM queue_jobs WHERE queue_name = 'deactivate-user'",
)
.fetch_one(&pool)
.await
.expect("Deactivation job to be scheduled");
assert_eq!(job["user_id"], serde_json::json!(user.id));
// Make sure to run the jobs in the queue
state.run_jobs_in_queue().await;

let request = Request::get(format!("/api/admin/v1/users/{}", user.id))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();

assert_json_snapshot!(body, @r#"
{
"data": {
"type": "user",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"username": "alice",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": "2022-01-16T14:40:00Z",
"deactivated_at": "2022-01-16T14:40:00Z",
"admin": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
}
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
"#);
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
@@ -179,15 +202,37 @@ mod tests {
serde_json::json!(state.clock.now())
);

// It should have scheduled a deactivation job for the user
// XXX: we don't have a good way to look for the deactivation job
let job: Json<serde_json::Value> = sqlx::query_scalar(
"SELECT payload FROM queue_jobs WHERE queue_name = 'deactivate-user'",
)
.fetch_one(&pool)
.await
.expect("Deactivation job to be scheduled");
assert_eq!(job["user_id"], serde_json::json!(user.id));
// Make sure to run the jobs in the queue
state.run_jobs_in_queue().await;

let request = Request::get(format!("/api/admin/v1/users/{}", user.id))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();

assert_json_snapshot!(body, @r#"
{
"data": {
"type": "user",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"username": "alice",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": "2022-01-16T14:40:00Z",
"deactivated_at": "2022-01-16T14:41:00Z",
"admin": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
}
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
"#);
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
33 changes: 33 additions & 0 deletions crates/handlers/src/test_utils.rs
@@ -29,6 +29,7 @@ use mas_axum_utils::{
};
use mas_config::RateLimitingConfig;
use mas_data_model::SiteConfig;
use mas_email::{MailTransport, Mailer};
use mas_i18n::Translator;
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{HomeserverConnection, MockHomeserverConnection};
@@ -39,6 +40,7 @@ use mas_storage::{
clock::MockClock,
};
use mas_storage_pg::PgRepositoryFactory;
use mas_tasks::QueueWorker;
use mas_templates::{SiteConfigExt, Templates};
use oauth2_types::{registration::ClientRegistrationResponse, requests::AccessTokenResponse};
use rand::SeedableRng;
@@ -113,6 +115,7 @@ pub(crate) struct TestState {
pub rng: Arc<Mutex<ChaChaRng>>,
pub http_client: reqwest::Client,
pub task_tracker: TaskTracker,
queue_worker: Arc<tokio::sync::Mutex<QueueWorker>>,

#[allow(dead_code)] // It is used, as it will cancel the CancellationToken when dropped
cancellation_drop_guard: Arc<DropGuard>,
@@ -235,6 +238,27 @@ impl TestState {
shutdown_token.child_token(),
);

let mailer = Mailer::new(
templates.clone(),
MailTransport::blackhole(),
"[email protected]".parse().unwrap(),
"[email protected]".parse().unwrap(),
);

let queue_worker = mas_tasks::init(
PgRepositoryFactory::new(pool.clone()),
Arc::clone(&clock),
&mailer,
homeserver_connection.clone(),
url_builder.clone(),
&site_config,
shutdown_token.child_token(),
)
.await
.unwrap();

let queue_worker = Arc::new(tokio::sync::Mutex::new(queue_worker));

Ok(Self {
repository_factory: PgRepositoryFactory::new(pool),
templates,
@@ -254,10 +278,19 @@ impl TestState {
rng,
http_client,
task_tracker,
queue_worker,
cancellation_drop_guard: Arc::new(shutdown_token.drop_guard()),
})
}

/// Run all the available jobs in the queue.
///
/// Panics if it fails to run the jobs (but not on job failures!)
pub async fn run_jobs_in_queue(&self) {
let mut queue = self.queue_worker.lock().await;
queue.process_all_jobs_in_tests().await.unwrap();
}

/// Reset the test utils to a fresh state, with the same configuration.
pub async fn reset(self) -> Self {
let site_config = self.site_config.clone();
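For reference, a handler test built on the new run_jobs_in_queue helper might look like the following sketch, following the pattern used in deactivate.rs above. TestState::from_pool and setup are assumed helpers from this crate's test utilities and are not shown in this diff.

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_runs_queued_jobs(pool: PgPool) {
    setup();
    let state = TestState::from_pool(pool).await.unwrap();

    // ... perform a request that schedules a background job ...

    // Drain the queue so the job's side effects are visible to the assertions
    state.run_jobs_in_queue().await;

    // ... assert on the resulting state ...
}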
2 changes: 1 addition & 1 deletion crates/storage/src/clock.rs
@@ -15,7 +15,7 @@ use std::sync::{Arc, atomic::AtomicI64};
use chrono::{DateTime, TimeZone, Utc};

/// Represents a clock which can give the current date and time
pub trait Clock: Sync {
pub trait Clock: Send + Sync {
/// Get the current date and time
fn now(&self) -> DateTime<Utc>;
}
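The widened bound matters because the task worker now holds its clock as Arc<dyn Clock> and that state is moved into a spawned task (see the crates/tasks/src/lib.rs changes below, where task_tracker.spawn(worker.run()) is called). A self-contained sketch of the constraint, using illustrative types rather than anything from this repository:

use std::sync::Arc;

use chrono::{DateTime, Utc};

// Same shape as the trait above: without Send, the Arc<dyn Clock> below
// could not be captured by a future handed to tokio::spawn.
trait Clock: Send + Sync {
    fn now(&self) -> DateTime<Utc>;
}

// A fixed clock, standing in for a system or mock clock.
struct FixedClock(DateTime<Utc>);

impl Clock for FixedClock {
    fn now(&self) -> DateTime<Utc> {
        self.0
    }
}

#[tokio::main]
async fn main() {
    let clock: Arc<dyn Clock> = Arc::new(FixedClock(Utc::now()));

    // tokio::spawn requires the future (and everything it captures,
    // including the Arc<dyn Clock>) to be Send.
    let handle = tokio::spawn(async move {
        println!("now = {}", clock.now());
    });
    handle.await.unwrap();
}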
2 changes: 1 addition & 1 deletion crates/tasks/src/database.rs
@@ -24,7 +24,7 @@ impl RunnableJob for CleanupExpiredTokensJob {

let count = repo
.oauth2_access_token()
.cleanup_revoked(&clock)
.cleanup_revoked(clock)
.await
.map_err(JobError::retry)?;
repo.save().await.map_err(JobError::retry)?;
2 changes: 1 addition & 1 deletion crates/tasks/src/email.rs
@@ -100,7 +100,7 @@ impl RunnableJob for SendEmailAuthenticationCodeJob {
.user_email()
.add_authentication_code(
&mut rng,
&clock,
clock,
Duration::minutes(5), // TODO: make this configurable
&user_email_authentication,
code,
56 changes: 45 additions & 11 deletions crates/tasks/src/lib.rs
@@ -10,14 +10,16 @@ use mas_data_model::SiteConfig;
use mas_email::Mailer;
use mas_matrix::HomeserverConnection;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, RepositoryError, RepositoryFactory, SystemClock};
use mas_storage::{BoxRepository, Clock, RepositoryError, RepositoryFactory};
use mas_storage_pg::PgRepositoryFactory;
use new_queue::QueueRunnerError;
use opentelemetry::metrics::Meter;
use rand::SeedableRng;
use sqlx::{Pool, Postgres};
use tokio_util::{sync::CancellationToken, task::TaskTracker};

pub use crate::new_queue::QueueWorker;

mod database;
mod email;
mod matrix;
@@ -39,7 +41,7 @@ static METER: LazyLock<Meter> = LazyLock::new(|| {
struct State {
repository_factory: PgRepositoryFactory,
mailer: Mailer,
clock: SystemClock,
clock: Arc<dyn Clock>,
homeserver: Arc<dyn HomeserverConnection>,
url_builder: UrlBuilder,
site_config: SiteConfig,
@@ -48,7 +50,7 @@ struct State {
impl State {
pub fn new(
repository_factory: PgRepositoryFactory,
clock: SystemClock,
clock: impl Clock + 'static,
mailer: Mailer,
homeserver: impl HomeserverConnection + 'static,
url_builder: UrlBuilder,
@@ -57,7 +59,7 @@ impl State {
Self {
repository_factory,
mailer,
clock,
clock: Arc::new(clock),
homeserver: Arc::new(homeserver),
url_builder,
site_config,
@@ -68,8 +70,8 @@ impl State {
self.repository_factory.pool()
}

pub fn clock(&self) -> BoxClock {
Box::new(self.clock.clone())
pub fn clock(&self) -> &dyn Clock {
&self.clock
}

pub fn mailer(&self) -> &Mailer {
@@ -99,29 +101,31 @@ impl State {
}
}

/// Initialise the workers.
/// Initialise the worker, without running it.
///
/// This is mostly useful for tests.
///
/// # Errors
///
/// This function can fail if the database connection fails.
pub async fn init(
repository_factory: PgRepositoryFactory,
clock: impl Clock + 'static,
mailer: &Mailer,
homeserver: impl HomeserverConnection + 'static,
url_builder: UrlBuilder,
site_config: &SiteConfig,
cancellation_token: CancellationToken,
task_tracker: &TaskTracker,
) -> Result<(), QueueRunnerError> {
) -> Result<QueueWorker, QueueRunnerError> {
let state = State::new(
repository_factory,
SystemClock::default(),
clock,
mailer.clone(),
homeserver,
url_builder,
site_config.clone(),
);
let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?;
let mut worker = QueueWorker::new(state, cancellation_token).await?;

worker
.register_handler::<mas_storage::queue::CleanupExpiredTokensJob>()
@@ -157,6 +161,36 @@ pub async fn init(
mas_storage::queue::PruneStalePolicyDataJob,
);

Ok(worker)
}

/// Initialise the worker and run it.
///
/// # Errors
///
/// This function can fail if the database connection fails.
#[expect(clippy::too_many_arguments, reason = "this is fine")]
pub async fn init_and_run(
repository_factory: PgRepositoryFactory,
clock: impl Clock + 'static,
mailer: &Mailer,
homeserver: impl HomeserverConnection + 'static,
url_builder: UrlBuilder,
site_config: &SiteConfig,
cancellation_token: CancellationToken,
task_tracker: &TaskTracker,
) -> Result<(), QueueRunnerError> {
let worker = init(
repository_factory,
clock,
mailer,
homeserver,
url_builder,
site_config,
cancellation_token,
)
.await?;

task_tracker.spawn(worker.run());

Ok(())