diff --git a/Cargo.lock b/Cargo.lock index e855178f2..a8212b7b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3665,6 +3665,7 @@ dependencies = [ "mas-jose", "mas-storage", "oauth2-types", + "opentelemetry", "opentelemetry-semantic-conventions", "rand 0.8.5", "rand_chacha 0.3.1", diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index cd4ae44ad..55b592aea 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Instant}; +use std::{convert::Infallible, net::IpAddr, sync::Arc}; use axum::extract::{FromRef, FromRequestParts}; use ipnetwork::IpNetwork; @@ -19,10 +19,12 @@ use mas_keystore::{Encrypter, Keystore}; use mas_matrix::HomeserverConnection; use mas_policy::{Policy, PolicyFactory}; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock}; -use mas_storage_pg::PgRepository; +use mas_storage::{ + BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, RepositoryFactory, SystemClock, +}; +use mas_storage_pg::PgRepositoryFactory; use mas_templates::Templates; -use opentelemetry::{KeyValue, metrics::Histogram}; +use opentelemetry::KeyValue; use rand::SeedableRng; use sqlx::PgPool; use tracing::Instrument; @@ -31,7 +33,7 @@ use crate::telemetry::METER; #[derive(Clone)] pub struct AppState { - pub pool: PgPool, + pub repository_factory: PgRepositoryFactory, pub templates: Templates, pub key_store: Keystore, pub cookie_manager: CookieManager, @@ -47,13 +49,12 @@ pub struct AppState { pub activity_tracker: ActivityTracker, pub trusted_proxies: Vec, pub limiter: Limiter, - pub conn_acquisition_histogram: Option>, } impl AppState { /// Init the metrics for the app state. 
pub fn init_metrics(&mut self) { - let pool = self.pool.clone(); + let pool = self.repository_factory.pool(); METER .i64_observable_up_down_counter("db.connections.usage") .with_description("The number of connections that are currently in `state` described by the state attribute.") @@ -66,7 +67,7 @@ impl AppState { }) .build(); - let pool = self.pool.clone(); + let pool = self.repository_factory.pool(); METER .i64_observable_up_down_counter("db.connections.max") .with_description("The maximum number of open connections allowed.") @@ -76,26 +77,18 @@ impl AppState { instrument.observe(i64::from(max_conn), &[]); }) .build(); - - // Track the connection acquisition time - let histogram = METER - .u64_histogram("db.client.connections.create_time") - .with_description("The time it took to create a new connection.") - .with_unit("ms") - .build(); - self.conn_acquisition_histogram = Some(histogram); } /// Init the metadata cache in the background pub fn init_metadata_cache(&self) { - let pool = self.pool.clone(); + let factory = self.repository_factory.clone(); let metadata_cache = self.metadata_cache.clone(); let http_client = self.http_client.clone(); tokio::spawn( LogContext::new("metadata-cache-warmup") .run(async move || { - let conn = match pool.acquire().await { + let mut repo = match factory.create().await { Ok(conn) => conn, Err(e) => { tracing::error!( @@ -106,8 +99,6 @@ impl AppState { } }; - let mut repo = PgRepository::from_conn(conn); - if let Err(e) = metadata_cache .warm_up_and_run( &http_client, @@ -127,9 +118,17 @@ impl AppState { } } +// XXX(quenting): we only use this for the healthcheck endpoint, checking the db +// should be part of the repository impl FromRef for PgPool { fn from_ref(input: &AppState) -> Self { - input.pool.clone() + input.repository_factory.pool() + } +} + +impl FromRef for BoxRepositoryFactory { + fn from_ref(input: &AppState) -> Self { + input.repository_factory.clone().boxed() } } @@ -359,23 +358,13 @@ impl FromRequestParts for RequesterFingerprint { } impl FromRequestParts for BoxRepository { - type Rejection = ErrorWrapper; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, state: &AppState, ) -> Result { - let start = Instant::now(); - let repo = PgRepository::from_pool(&state.pool).await?; - - // Measure the time it took to create the connection - let duration = start.elapsed(); - let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX); - - if let Some(histogram) = &state.conn_acquisition_histogram { - histogram.record(duration_ms, &[]); - } - - Ok(repo.boxed()) + let repo = state.repository_factory.create().await?; + Ok(repo) } } diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index a82d8f059..2c004974f 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -11,6 +11,7 @@ use figment::Figment; use mas_config::{ ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig, }; +use mas_storage_pg::PgRepositoryFactory; use tracing::{info, info_span}; use crate::util::{ @@ -48,7 +49,8 @@ impl Options { if with_dynamic_data { let database_config = DatabaseConfig::extract(figment)?; let pool = database_pool_from_config(&database_config).await?; - load_policy_factory_dynamic_data(&policy_factory, &pool).await?; + let repository_factory = PgRepositoryFactory::new(pool.clone()); + load_policy_factory_dynamic_data(&policy_factory, &repository_factory).await?; } let _instance = 
policy_factory.instantiate().await?; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index de94ce7b6..4f2fc6205 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache}; use mas_listener::server::Server; use mas_router::UrlBuilder; use mas_storage::SystemClock; -use mas_storage_pg::MIGRATOR; +use mas_storage_pg::{MIGRATOR, PgRepositoryFactory}; use sqlx::migrate::Migrate; use tracing::{Instrument, info, info_span, warn}; @@ -134,7 +134,7 @@ impl Options { load_policy_factory_dynamic_data_continuously( &policy_factory, - &pool, + PgRepositoryFactory::new(pool.clone()).boxed(), shutdown.soft_shutdown_token(), shutdown.task_tracker(), ) @@ -172,7 +172,7 @@ impl Options { info!("Starting task worker"); mas_tasks::init( - &pool, + PgRepositoryFactory::new(pool.clone()), &mailer, homeserver_connection.clone(), url_builder.clone(), @@ -193,7 +193,7 @@ impl Options { // Initialize the activity tracker // Activity is flushed every minute let activity_tracker = ActivityTracker::new( - pool.clone(), + PgRepositoryFactory::new(pool.clone()).boxed(), Duration::from_secs(60), shutdown.task_tracker(), shutdown.soft_shutdown_token(), @@ -215,7 +215,7 @@ impl Options { limiter.start(); let graphql_schema = mas_handlers::graphql_schema( - &pool, + PgRepositoryFactory::new(pool.clone()).boxed(), &policy_factory, homeserver_connection.clone(), site_config.clone(), @@ -226,7 +226,7 @@ impl Options { let state = { let mut s = AppState { - pool, + repository_factory: PgRepositoryFactory::new(pool), templates, key_store, cookie_manager, @@ -242,7 +242,6 @@ impl Options { activity_tracker, trusted_proxies, limiter, - conn_acquisition_histogram: None, }; s.init_metrics(); s.init_metadata_cache(); diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index da16e848a..f13a1ae3c 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -10,6 +10,7 @@ use clap::Parser; use figment::Figment; use mas_config::{AppConfig, ConfigurationSection}; use mas_router::UrlBuilder; +use mas_storage_pg::PgRepositoryFactory; use tracing::{info, info_span}; use crate::{ @@ -63,7 +64,7 @@ impl Options { info!("Starting task scheduler"); mas_tasks::init( - &pool, + PgRepositoryFactory::new(pool.clone()), &mailer, conn, url_builder, diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index e66966103..84e94aeea 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -20,8 +20,7 @@ use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection}; use mas_matrix_synapse::SynapseConnection; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; -use mas_storage::RepositoryAccess; -use mas_storage_pg::PgRepository; +use mas_storage::{BoxRepositoryFactory, RepositoryAccess, RepositoryFactory}; use mas_templates::{SiteConfigExt, Templates}; use sqlx::{ ConnectOptions, Executor, PgConnection, PgPool, @@ -400,14 +399,13 @@ pub async fn database_connection_from_config_with_options( // XXX: this could be put somewhere else? 
pub async fn load_policy_factory_dynamic_data_continuously( policy_factory: &Arc, - pool: &PgPool, + repository_factory: BoxRepositoryFactory, cancellation_token: CancellationToken, task_tracker: &TaskTracker, ) -> Result<(), anyhow::Error> { let policy_factory = policy_factory.clone(); - let pool = pool.clone(); - load_policy_factory_dynamic_data(&policy_factory, &pool).await?; + load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await?; task_tracker.spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); @@ -420,7 +418,9 @@ pub async fn load_policy_factory_dynamic_data_continuously( _ = interval.tick() => {} } - if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await { + if let Err(err) = + load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await + { tracing::error!( error = ?err, "Failed to load policy factory dynamic data" @@ -438,9 +438,10 @@ pub async fn load_policy_factory_dynamic_data_continuously( #[tracing::instrument(name = "policy.load_dynamic_data", skip_all)] pub async fn load_policy_factory_dynamic_data( policy_factory: &PolicyFactory, - pool: &PgPool, + repository_factory: &(dyn RepositoryFactory + Send + Sync), ) -> Result<(), anyhow::Error> { - let mut repo = PgRepository::from_pool(pool) + let mut repo = repository_factory + .create() .await .context("Failed to acquire database connection")?; diff --git a/crates/handlers/src/activity_tracker/mod.rs b/crates/handlers/src/activity_tracker/mod.rs index 3f7511af6..56785e236 100644 --- a/crates/handlers/src/activity_tracker/mod.rs +++ b/crates/handlers/src/activity_tracker/mod.rs @@ -11,8 +11,7 @@ use std::net::IpAddr; use chrono::{DateTime, Utc}; use mas_data_model::{BrowserSession, CompatSession, Session}; -use mas_storage::Clock; -use sqlx::PgPool; +use mas_storage::{BoxRepositoryFactory, Clock}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use ulid::Ulid; @@ -61,12 +60,12 @@ impl ActivityTracker { /// time, when the cancellation token is cancelled. #[must_use] pub fn new( - pool: PgPool, + repository_factory: BoxRepositoryFactory, flush_interval: std::time::Duration, task_tracker: &TaskTracker, cancellation_token: CancellationToken, ) -> Self { - let worker = Worker::new(pool); + let worker = Worker::new(repository_factory); let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE); let tracker = ActivityTracker { channel: sender }; diff --git a/crates/handlers/src/activity_tracker/worker.rs b/crates/handlers/src/activity_tracker/worker.rs index 80853f0fe..4787964ee 100644 --- a/crates/handlers/src/activity_tracker/worker.rs +++ b/crates/handlers/src/activity_tracker/worker.rs @@ -7,12 +7,13 @@ use std::{collections::HashMap, net::IpAddr}; use chrono::{DateTime, Utc}; -use mas_storage::{RepositoryAccess, RepositoryError, user::BrowserSessionRepository}; +use mas_storage::{ + BoxRepositoryFactory, RepositoryAccess, RepositoryError, user::BrowserSessionRepository, +}; use opentelemetry::{ Key, KeyValue, metrics::{Counter, Gauge, Histogram}, }; -use sqlx::PgPool; use tokio_util::sync::CancellationToken; use ulid::Ulid; @@ -43,7 +44,7 @@ struct ActivityRecord { /// Handles writing activity records to the database. 
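// Illustrative sketch, not part of the patch: `load_policy_factory_dynamic_data`
// now takes a `&(dyn RepositoryFactory + Send + Sync)`, and the two call sites in
// this patch reach that type differently. A concrete `PgRepositoryFactory` coerces
// implicitly, while a `BoxRepositoryFactory` is reborrowed through the box with
// `&*`. `example_coercions` is a hypothetical helper assumed to live next to
// `load_policy_factory_dynamic_data` in `crates/cli/src/util.rs`.
use mas_policy::PolicyFactory;
use mas_storage::BoxRepositoryFactory;
use mas_storage_pg::PgRepositoryFactory;
use sqlx::PgPool;

async fn example_coercions(policy_factory: &PolicyFactory, pool: PgPool) -> anyhow::Result<()> {
    // `&PgRepositoryFactory` coerces to `&(dyn RepositoryFactory + Send + Sync)`,
    // as in the `debug` command above.
    let concrete = PgRepositoryFactory::new(pool);
    load_policy_factory_dynamic_data(policy_factory, &concrete).await?;

    // A boxed factory is dereferenced and reborrowed, as in
    // `load_policy_factory_dynamic_data_continuously`.
    let boxed: BoxRepositoryFactory = concrete.boxed();
    load_policy_factory_dynamic_data(policy_factory, &*boxed).await?;

    Ok(())
}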
pub struct Worker { - pool: PgPool, + repository_factory: BoxRepositoryFactory, pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>, pending_records_gauge: Gauge, message_counter: Counter, @@ -51,7 +52,7 @@ pub struct Worker { } impl Worker { - pub(crate) fn new(pool: PgPool) -> Self { + pub(crate) fn new(repository_factory: BoxRepositoryFactory) -> Self { let message_counter = METER .u64_counter("mas.activity_tracker.messages") .with_description("The number of messages received by the activity tracker") @@ -89,7 +90,7 @@ impl Worker { pending_records_gauge.record(0, &[]); Self { - pool, + repository_factory, pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS), pending_records_gauge, message_counter, @@ -218,11 +219,7 @@ impl Worker { #[tracing::instrument(name = "activity_tracker.flush", skip(self))] async fn try_flush(&mut self) -> Result<(), RepositoryError> { let pending_records = &self.pending_records; - - let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool) - .await - .map_err(RepositoryError::from_error)? - .boxed(); + let mut repo = self.repository_factory.create().await?; let mut browser_sessions = Vec::new(); let mut oauth2_sessions = Vec::new(); diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 76148df75..3dbbb0d93 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -14,11 +14,12 @@ use mas_axum_utils::record_error; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, SiteConfig, TokenType, User}; use mas_matrix::HomeserverConnection; use mas_storage::{ - BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, + BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, Clock, RepositoryAccess, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, }, + queue::{QueueJobRepositoryExt as _, SyncDevicesJob}, user::{UserPasswordRepository, UserRepository}, }; use opentelemetry::{Key, KeyValue, metrics::Counter}; @@ -268,7 +269,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - mut repo: BoxRepository, + State(repository_factory): State, activity_tracker: BoundActivityTracker, State(homeserver): State>, State(site_config): State, @@ -279,6 +280,7 @@ pub(crate) async fn post( ) -> Result { let user_agent = user_agent.map(|ua| ua.as_str().to_owned()); let login_type = input.credentials.login_type(); + let mut repo = repository_factory.create().await?; let (mut session, user) = match (password_manager.is_enabled(), input.credentials) { ( true, @@ -301,6 +303,9 @@ pub(crate) async fn post( } }; + // Try getting the localpart out of the MXID + let username = homeserver.localpart(&user).unwrap_or(&user); + user_password_login( &mut rng, &clock, @@ -308,8 +313,7 @@ pub(crate) async fn post( &limiter, requester, &mut repo, - &homeserver, - user, + username, password, input.device_id, // TODO check for validity input.initial_device_display_name, @@ -322,7 +326,6 @@ pub(crate) async fn post( &mut rng, &clock, &mut repo, - &homeserver, &token, input.device_id, input.initial_device_display_name, @@ -368,12 +371,53 @@ pub(crate) async fn post( None }; + // Ideally, we'd keep the lock whilst we actually create the device, but we + // really want to stop holding the transaction while we talk to the + // homeserver. 
+    //
+    // In practice, this is fine, because:
+    // - the session exists after we committed the transaction, so a sync job won't
+    //   try to delete it
+    // - we've acquired a lock on the user before creating the session, meaning
+    //   we've made sure that sync jobs finished before we create the new session
+    // - we're in the read-committed isolation level, which means the sync will see
+    //   what we've committed and won't try to delete the session once we release
+    //   the lock
     repo.save().await?;

     activity_tracker
         .record_compat_session(&clock, &session)
         .await;

+    // This session is guaranteed to have a device on it: both login methods
+    // create one
+    let Some(device) = &session.device else {
+        unreachable!()
+    };
+
+    // Now we can create the device on the homeserver, without holding the
+    // transaction
+    if let Err(err) = homeserver
+        .create_device(&user_id, device.as_str(), session.human_name.as_deref())
+        .await
+    {
+        // Something went wrong, let's end this session and schedule a device sync
+        let mut repo = repository_factory.create().await?;
+        let session = repo.compat_session().finish(&clock, session).await?;
+
+        repo.queue_job()
+            .schedule_job(
+                &mut rng,
+                &clock,
+                SyncDevicesJob::new_for_id(session.user_id),
+            )
+            .await?;
+
+        repo.save().await?;
+
+        return Err(RouteError::ProvisionDeviceFailed(err));
+    }
+
     LOGIN_COUNTER.add(
         1,
         &[
@@ -395,7 +439,6 @@ async fn token_login(
     rng: &mut (dyn RngCore + Send),
     clock: &dyn Clock,
     repo: &mut BoxRepository,
-    homeserver: &dyn HomeserverConnection,
     token: &str,
     requested_device_id: Option<String>,
     initial_device_display_name: Option<String>,
@@ -461,7 +504,8 @@ async fn token_login(
         return Err(RouteError::InvalidLoginToken);
     }

-    // Lock the user sync to make sure we don't get into a race condition
+    // We're about to create a device, let's explicitly acquire a lock, so that
+    // any concurrent sync will read after we've committed
     repo.user()
         .acquire_lock_for_sync(&browser_session.user)
         .await?;
@@ -471,20 +515,14 @@ async fn token_login(
     } else {
         Device::generate(rng)
     };
-    let mxid = homeserver.mxid(&browser_session.user.username);
-    homeserver
-        .create_device(
-            &mxid,
-            device.as_str(),
-            initial_device_display_name.as_deref(),
-        )
-        .await
-        .map_err(RouteError::ProvisionDeviceFailed)?;

     repo.app_session()
         .finish_sessions_to_replace_device(clock, &browser_session.user, &device)
         .await?;

+    // We first create the session in the database, commit the transaction, then
+    // create it on the homeserver, scheduling a device sync job afterwards to
+    // make sure we don't end up in an inconsistent state.
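// Illustrative sketch, not part of the patch: injecting `State<BoxRepositoryFactory>`
// instead of a ready-made `BoxRepository` is what makes the two-phase shape above
// possible in general: commit one transaction, call out to an external service with
// no transaction held, then open a second transaction to compensate if needed.
// `call_homeserver` and `two_phase_handler` are hypothetical stand-ins, and
// error-to-response conversion is omitted.
use axum::extract::State;
use mas_storage::{BoxRepositoryFactory, RepositoryError};

async fn call_homeserver() -> Result<(), anyhow::Error> {
    // Stand-in for the provisioning call made between the two transactions.
    Ok(())
}

async fn two_phase_handler(
    State(repository_factory): State<BoxRepositoryFactory>,
) -> Result<(), RepositoryError> {
    // First transaction: make the new record durable before any external call.
    let repo = repository_factory.create().await?;
    // ... write through `repo` here ...
    repo.save().await?;

    if call_homeserver().await.is_err() {
        // Second transaction: compensate (finish the record, queue a sync job, ...)
        // rather than rolling anything back.
        let repo = repository_factory.create().await?;
        // ... record the compensation here ...
        repo.save().await?;
    }

    Ok(())
}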
let compat_session = repo .compat_session() .add( @@ -512,15 +550,11 @@ async fn user_password_login( limiter: &Limiter, requester: RequesterFingerprint, repo: &mut BoxRepository, - homeserver: &dyn HomeserverConnection, - username: String, + username: &str, password: String, requested_device_id: Option, initial_device_display_name: Option, ) -> Result<(CompatSession, User), RouteError> { - // Try getting the localpart out of the MXID - let username = homeserver.localpart(&username).unwrap_or(&username); - // Find the user let user = repo .user() @@ -566,25 +600,16 @@ async fn user_password_login( .await?; } - // Lock the user sync to make sure we don't get into a race condition + // We're about to create a device, let's explicitly acquire a lock, so that + // any concurrent sync will read after we've committed repo.user().acquire_lock_for_sync(&user).await?; - let mxid = homeserver.mxid(&user.username); - // Now that the user credentials have been verified, start a new compat session let device = if let Some(requested_device_id) = requested_device_id { Device::from(requested_device_id) } else { Device::generate(&mut rng) }; - homeserver - .create_device( - &mxid, - device.as_str(), - initial_device_display_name.as_deref(), - ) - .await - .map_err(RouteError::ProvisionDeviceFailed)?; repo.app_session() .finish_sessions_to_replace_device(clock, &user, &device) diff --git a/crates/handlers/src/graphql/mod.rs b/crates/handlers/src/graphql/mod.rs index 9428abe7b..cfedd69e9 100644 --- a/crates/handlers/src/graphql/mod.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -32,12 +32,12 @@ use mas_data_model::{BrowserSession, Session, SiteConfig, User}; use mas_matrix::HomeserverConnection; use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock}; -use mas_storage_pg::PgRepository; +use mas_storage::{ + BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, Clock, RepositoryError, SystemClock, +}; use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME}; use rand::{SeedableRng, thread_rng}; use rand_chacha::ChaChaRng; -use sqlx::PgPool; use state::has_session_ended; use tracing::{Instrument, info_span}; use ulid::Ulid; @@ -69,7 +69,7 @@ pub struct ExtraRouterParameters { } struct GraphQLState { - pool: PgPool, + repository_factory: BoxRepositoryFactory, homeserver_connection: Arc, policy_factory: Arc, site_config: SiteConfig, @@ -81,11 +81,7 @@ struct GraphQLState { #[async_trait::async_trait] impl state::State for GraphQLState { async fn repository(&self) -> Result { - let repo = PgRepository::from_pool(&self.pool) - .await - .map_err(RepositoryError::from_error)?; - - Ok(repo.boxed()) + self.repository_factory.create().await } async fn policy(&self) -> Result { @@ -128,7 +124,7 @@ impl state::State for GraphQLState { #[must_use] pub fn schema( - pool: &PgPool, + repository_factory: BoxRepositoryFactory, policy_factory: &Arc, homeserver_connection: impl HomeserverConnection + 'static, site_config: SiteConfig, @@ -137,7 +133,7 @@ pub fn schema( limiter: Limiter, ) -> Schema { let state = GraphQLState { - pool: pool.clone(), + repository_factory, policy_factory: Arc::clone(policy_factory), homeserver_connection: Arc::new(homeserver_connection), site_config, diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 36028cb3e..40e2f5f26 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -42,7 +42,7 @@ use 
mas_keystore::{Encrypter, Keystore}; use mas_matrix::HomeserverConnection; use mas_policy::Policy; use mas_router::{Route, UrlBuilder}; -use mas_storage::{BoxClock, BoxRepository, BoxRng}; +use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng}; use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates}; use opentelemetry::metrics::Meter; use sqlx::PgPool; @@ -265,6 +265,7 @@ where Arc: FromRef, PasswordManager: FromRef, Limiter: FromRef, + BoxRepositoryFactory: FromRef, BoundActivityTracker: FromRequestParts, RequesterFingerprint: FromRequestParts, BoxRepository: FromRequestParts, diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index b6d9fba9d..d69bda190 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -34,8 +34,11 @@ use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::{SimpleRoute, UrlBuilder}; -use mas_storage::{BoxClock, BoxRepository, BoxRng, clock::MockClock}; -use mas_storage_pg::{DatabaseError, PgRepository}; +use mas_storage::{ + BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, RepositoryError, RepositoryFactory, + clock::MockClock, +}; +use mas_storage_pg::PgRepositoryFactory; use mas_templates::{SiteConfigExt, Templates}; use oauth2_types::{registration::ClientRegistrationResponse, requests::AccessTokenResponse}; use rand::SeedableRng; @@ -92,7 +95,7 @@ pub(crate) async fn policy_factory( #[derive(Clone)] pub(crate) struct TestState { - pub pool: PgPool, + pub repository_factory: PgRepositoryFactory, pub templates: Templates, pub key_store: Keystore, pub cookie_manager: CookieManager, @@ -209,7 +212,7 @@ impl TestState { let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap(); let graphql_state = TestGraphQLState { - pool: pool.clone(), + repository_factory: PgRepositoryFactory::new(pool.clone()).boxed(), policy_factory: Arc::clone(&policy_factory), homeserver_connection: Arc::clone(&homeserver_connection), site_config: site_config.clone(), @@ -224,14 +227,14 @@ impl TestState { let graphql_schema = graphql::schema_builder().data(state).finish(); let activity_tracker = ActivityTracker::new( - pool.clone(), + PgRepositoryFactory::new(pool.clone()).boxed(), std::time::Duration::from_secs(60), &task_tracker, shutdown_token.child_token(), ); Ok(Self { - pool, + repository_factory: PgRepositoryFactory::new(pool), templates, key_store, cookie_manager, @@ -256,7 +259,7 @@ impl TestState { /// Reset the test utils to a fresh state, with the same configuration. pub async fn reset(self) -> Self { let site_config = self.site_config.clone(); - let pool = self.pool.clone(); + let pool = self.repository_factory.pool(); let task_tracker = self.task_tracker.clone(); // This should trigger the cancellation drop guard @@ -351,9 +354,8 @@ impl TestState { access_token } - pub async fn repository(&self) -> Result { - let repo = PgRepository::from_pool(&self.pool).await?; - Ok(repo.boxed()) + pub async fn repository(&self) -> Result { + self.repository_factory.create().await } /// Returns a new random number generator. 
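// Illustrative sketch, not part of the patch: because the GraphQL state and the
// test harness now only depend on a `BoxRepositoryFactory`, a test double can stand
// in for the database, for example to exercise error paths without Postgres.
// `FailingRepositoryFactory` and `failing_factory` are hypothetical names.
use async_trait::async_trait;
use mas_storage::{BoxRepository, BoxRepositoryFactory, RepositoryError, RepositoryFactory};

struct FailingRepositoryFactory;

#[async_trait]
impl RepositoryFactory for FailingRepositoryFactory {
    async fn create(&self) -> Result<BoxRepository, RepositoryError> {
        // Pretend the database is unreachable.
        Err(RepositoryError::from_error(std::io::Error::other(
            "database is down",
        )))
    }
}

fn failing_factory() -> BoxRepositoryFactory {
    Box::new(FailingRepositoryFactory)
}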
@@ -393,7 +395,7 @@ impl TestState { } struct TestGraphQLState { - pool: PgPool, + repository_factory: BoxRepositoryFactory, homeserver_connection: Arc, site_config: SiteConfig, policy_factory: Arc, @@ -407,11 +409,7 @@ struct TestGraphQLState { #[async_trait::async_trait] impl graphql::State for TestGraphQLState { async fn repository(&self) -> Result { - let repo = PgRepository::from_pool(&self.pool) - .await - .map_err(mas_storage::RepositoryError::from_error)?; - - Ok(repo.boxed()) + self.repository_factory.create().await } async fn policy(&self) -> Result { @@ -451,7 +449,13 @@ impl graphql::State for TestGraphQLState { impl FromRef for PgPool { fn from_ref(input: &TestState) -> Self { - input.pool.clone() + input.repository_factory.pool() + } +} + +impl FromRef for BoxRepositoryFactory { + fn from_ref(input: &TestState) -> Self { + input.repository_factory.clone().boxed() } } @@ -598,14 +602,14 @@ impl FromRequestParts for BoxRng { } impl FromRequestParts for BoxRepository { - type Rejection = ErrorWrapper; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, state: &TestState, ) -> Result { - let repo = PgRepository::from_pool(&state.pool).await?; - Ok(repo.boxed()) + let repo = state.repository_factory.create().await?; + Ok(repo) } } diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml index 6b45fb0a6..2a5e50447 100644 --- a/crates/storage-pg/Cargo.toml +++ b/crates/storage-pg/Cargo.toml @@ -21,6 +21,7 @@ serde_json.workspace = true thiserror.workspace = true tracing.workspace = true futures-util.workspace = true +opentelemetry.workspace = true opentelemetry-semantic-conventions.workspace = true rand.workspace = true diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 8971488a5..30882cfa8 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -175,10 +175,15 @@ pub(crate) mod iden; pub(crate) mod pagination; pub(crate) mod policy_data; pub(crate) mod repository; +pub(crate) mod telemetry; pub(crate) mod tracing; pub(crate) use self::errors::DatabaseInconsistencyError; -pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt}; +pub use self::{ + errors::DatabaseError, + repository::{PgRepository, PgRepositoryFactory}, + tracing::ExecuteExt, +}; /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = { diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 901f1fd45..c6668c2e4 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -6,9 +6,11 @@ use std::ops::{Deref, DerefMut}; +use async_trait::async_trait; use futures_util::{FutureExt, TryFutureExt, future::BoxFuture}; use mas_storage::{ - BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, + BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError, + RepositoryFactory, RepositoryTransaction, app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -46,6 +48,7 @@ use crate::{ job::PgQueueJobRepository, schedule::PgQueueScheduleRepository, worker::PgQueueWorkerRepository, }, + telemetry::DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -57,6 +60,51 @@ use crate::{ }, }; +/// An implementation of the 
+/// [`RepositoryFactory`] trait backed by a PostgreSQL
+/// connection pool.
+#[derive(Clone)]
+pub struct PgRepositoryFactory {
+    pool: PgPool,
+}
+
+impl PgRepositoryFactory {
+    /// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
+    #[must_use]
+    pub fn new(pool: PgPool) -> Self {
+        Self { pool }
+    }
+
+    /// Box the factory
+    #[must_use]
+    pub fn boxed(self) -> BoxRepositoryFactory {
+        Box::new(self)
+    }
+
+    /// Get the underlying PostgreSQL connection pool
+    #[must_use]
+    pub fn pool(&self) -> PgPool {
+        self.pool.clone()
+    }
+}
+
+#[async_trait]
+impl RepositoryFactory for PgRepositoryFactory {
+    async fn create(&self) -> Result<BoxRepository, RepositoryError> {
+        let start = std::time::Instant::now();
+        let repo = PgRepository::from_pool(&self.pool)
+            .await
+            .map_err(RepositoryError::from_error)?
+            .boxed();
+
+        // Measure the time it took to create the connection
+        let duration = start.elapsed();
+        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
+        DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM.record(duration_ms, &[]);
+
+        Ok(repo)
+    }
+}
+
 /// An implementation of the [`Repository`] trait backed by a PostgreSQL
 /// transaction.
 pub struct PgRepository<C = Transaction<'static, Postgres>> {
diff --git a/crates/storage-pg/src/telemetry.rs b/crates/storage-pg/src/telemetry.rs
new file mode 100644
index 000000000..93c74e74f
--- /dev/null
+++ b/crates/storage-pg/src/telemetry.rs
@@ -0,0 +1,31 @@
+// Copyright 2025 New Vector Ltd.
+//
+// SPDX-License-Identifier: AGPL-3.0-only
+// Please see LICENSE in the repository root for full details.
+
+use std::sync::LazyLock;
+
+use opentelemetry::{
+    InstrumentationScope,
+    metrics::{Histogram, Meter},
+};
+use opentelemetry_semantic_conventions as semcov;
+
+static SCOPE: LazyLock<InstrumentationScope> = LazyLock::new(|| {
+    InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
+        .with_version(env!("CARGO_PKG_VERSION"))
+        .with_schema_url(semcov::SCHEMA_URL)
+        .build()
+});
+
+static METER: LazyLock<Meter> =
+    LazyLock::new(|| opentelemetry::global::meter_with_scope(SCOPE.clone()));
+
+pub(crate) static DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM: LazyLock<Histogram<u64>> =
+    LazyLock::new(|| {
+        METER
+            .u64_histogram("db.client.connections.create_time")
+            .with_description("The time it took to create a new connection.")
+            .with_unit("ms")
+            .build()
+    });
diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs
index 923113a6a..07d8bd97c 100644
--- a/crates/storage/src/lib.rs
+++ b/crates/storage/src/lib.rs
@@ -128,7 +128,8 @@ pub use self::{
     clock::{Clock, SystemClock},
     pagination::{Page, Pagination},
     repository::{
-        BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
+        BoxRepository, BoxRepositoryFactory, Repository, RepositoryAccess, RepositoryError,
+        RepositoryFactory, RepositoryTransaction,
     },
     utils::{BoxClock, BoxRng, MapErr},
 };
diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs
index 2f051493c..93c43d469 100644
--- a/crates/storage/src/repository.rs
+++ b/crates/storage/src/repository.rs
@@ -4,6 +4,7 @@
 // SPDX-License-Identifier: AGPL-3.0-only
 // Please see LICENSE in the repository root for full details.

+use async_trait::async_trait;
 use futures_util::future::BoxFuture;
 use thiserror::Error;

@@ -29,6 +30,18 @@ use crate::{
     },
 };

+/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`]
+// XXX(quenting): this could be generic over the repository type, but it's annoying to make it
+// dyn-safe
+#[async_trait]
+pub trait RepositoryFactory {
+    /// Create a new [`BoxRepository`]
+    async fn create(&self) -> Result<BoxRepository, RepositoryError>;
+}
+
+/// A type-erased [`RepositoryFactory`]
+pub type BoxRepositoryFactory = Box<dyn RepositoryFactory + Send + Sync>;
+
 /// A [`Repository`] helps interacting with the underlying storage backend.
 pub trait Repository<E>:
     RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs
index 5ef7f5e84..cb1b16469 100644
--- a/crates/tasks/src/lib.rs
+++ b/crates/tasks/src/lib.rs
@@ -10,8 +10,8 @@ use mas_data_model::SiteConfig;
 use mas_email::Mailer;
 use mas_matrix::HomeserverConnection;
 use mas_router::UrlBuilder;
-use mas_storage::{BoxClock, BoxRepository, RepositoryError, SystemClock};
-use mas_storage_pg::PgRepository;
+use mas_storage::{BoxClock, BoxRepository, RepositoryError, RepositoryFactory, SystemClock};
+use mas_storage_pg::PgRepositoryFactory;
 use new_queue::QueueRunnerError;
 use opentelemetry::metrics::Meter;
 use rand::SeedableRng;
@@ -37,7 +37,7 @@ static METER: LazyLock<Meter> = LazyLock::new(|| {

 #[derive(Clone)]
 struct State {
-    pool: Pool<Postgres>,
+    repository_factory: PgRepositoryFactory,
     mailer: Mailer,
     clock: SystemClock,
     homeserver: Arc<dyn HomeserverConnection>,
@@ -47,7 +47,7 @@ struct State {

 impl State {
     pub fn new(
-        pool: Pool<Postgres>,
+        repository_factory: PgRepositoryFactory,
         clock: SystemClock,
         mailer: Mailer,
         homeserver: impl HomeserverConnection + 'static,
@@ -55,7 +55,7 @@ impl State {
         site_config: SiteConfig,
     ) -> Self {
         Self {
-            pool,
+            repository_factory,
             mailer,
             clock,
             homeserver: Arc::new(homeserver),
@@ -64,8 +64,8 @@ impl State {
         }
     }

-    pub fn pool(&self) -> &Pool<Postgres> {
-        &self.pool
+    pub fn pool(&self) -> Pool<Postgres> {
+        self.repository_factory.pool()
     }

     pub fn clock(&self) -> BoxClock {
@@ -83,12 +83,7 @@ impl State {
     }

     pub async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
-        let repo = PgRepository::from_pool(self.pool())
-            .await
-            .map_err(RepositoryError::from_error)?
-            .boxed();
-
-        Ok(repo)
+        self.repository_factory.create().await
     }

     pub fn matrix_connection(&self) -> &dyn HomeserverConnection {
@@ -110,7 +105,7 @@ impl State {
 ///
 /// This function can fail if the database connection fails.
 pub async fn init(
-    pool: &Pool<Postgres>,
+    repository_factory: PgRepositoryFactory,
     mailer: &Mailer,
     homeserver: impl HomeserverConnection + 'static,
     url_builder: UrlBuilder,
@@ -119,7 +114,7 @@ pub async fn init(
     task_tracker: &TaskTracker,
 ) -> Result<(), QueueRunnerError> {
     let state = State::new(
-        pool.clone(),
+        repository_factory,
         SystemClock::default(),
         mailer.clone(),
         homeserver,
diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs
index ea055e2f8..1c83b1720 100644
--- a/crates/tasks/src/new_queue.rs
+++ b/crates/tasks/src/new_queue.rs
@@ -224,7 +224,7 @@ impl QueueWorker {
         let mut rng = state.rng();
         let clock = state.clock();

-        let mut listener = PgListener::connect_with(state.pool())
+        let mut listener = PgListener::connect_with(&state.pool())
             .await
             .map_err(QueueRunnerError::SetupListener)?;
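Since every consumer now goes through the `RepositoryFactory` trait rather than holding a `PgPool` directly, the factory can also be wrapped with cross-cutting behaviour before it is boxed. A minimal sketch under that assumption follows; the wrapper type, its field name and its log message are illustrative and not part of this patch:

use async_trait::async_trait;
use mas_storage::{BoxRepository, BoxRepositoryFactory, RepositoryError, RepositoryFactory};

/// Hypothetical wrapper that logs how long each repository acquisition takes.
struct TimedRepositoryFactory {
    inner: BoxRepositoryFactory,
}

#[async_trait]
impl RepositoryFactory for TimedRepositoryFactory {
    async fn create(&self) -> Result<BoxRepository, RepositoryError> {
        let start = std::time::Instant::now();
        let repo = self.inner.create().await?;
        tracing::debug!(elapsed = ?start.elapsed(), "acquired a repository");
        Ok(repo)
    }
}

// Composes with the Postgres-backed factory introduced by this patch:
//
//     let factory: BoxRepositoryFactory = Box::new(TimedRepositoryFactory {
//         inner: PgRepositoryFactory::new(pool).boxed(),
//     });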