Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions crates/cli/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use mas_keystore::{Encrypter, Keystore};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
use mas_storage_pg::PgRepository;
use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, SystemClock, RepositoryFactory};
use mas_storage_pg::PgRepositoryFactory;
use mas_templates::Templates;
use opentelemetry::{KeyValue, metrics::Histogram};
use rand::SeedableRng;
Expand All @@ -31,7 +31,7 @@ use crate::telemetry::METER;

#[derive(Clone)]
pub struct AppState {
pub pool: PgPool,
pub repository_factory: PgRepositoryFactory,
pub templates: Templates,
pub key_store: Keystore,
pub cookie_manager: CookieManager,
Expand All @@ -53,7 +53,7 @@ pub struct AppState {
impl AppState {
/// Init the metrics for the app state.
pub fn init_metrics(&mut self) {
let pool = self.pool.clone();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.usage")
.with_description("The number of connections that are currently in `state` described by the state attribute.")
Expand All @@ -66,7 +66,7 @@ impl AppState {
})
.build();

let pool = self.pool.clone();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.max")
.with_description("The maximum number of open connections allowed.")
Expand All @@ -88,14 +88,14 @@ impl AppState {

/// Init the metadata cache in the background
pub fn init_metadata_cache(&self) {
let pool = self.pool.clone();
let factory = self.repository_factory.clone();
let metadata_cache = self.metadata_cache.clone();
let http_client = self.http_client.clone();

tokio::spawn(
LogContext::new("metadata-cache-warmup")
.run(async move || {
let conn = match pool.acquire().await {
let mut repo = match factory.create().await {
Ok(conn) => conn,
Err(e) => {
tracing::error!(
Expand All @@ -106,8 +106,6 @@ impl AppState {
}
};

let mut repo = PgRepository::from_conn(conn);

if let Err(e) = metadata_cache
.warm_up_and_run(
&http_client,
Expand All @@ -127,9 +125,17 @@ impl AppState {
}
}

// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
// should be part of the repository
impl FromRef<AppState> for PgPool {
fn from_ref(input: &AppState) -> Self {
input.pool.clone()
input.repository_factory.pool()
}
}

impl FromRef<AppState> for BoxRepositoryFactory {
fn from_ref(input: &AppState) -> Self {
input.repository_factory.clone().boxed()
}
}

Expand Down Expand Up @@ -359,14 +365,14 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
}

impl FromRequestParts<AppState> for BoxRepository {
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
type Rejection = ErrorWrapper<mas_storage::RepositoryError>;

async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let repo = PgRepository::from_pool(&state.pool).await?;
let repo = state.repository_factory.create().await?;

// Measure the time it took to create the connection
let duration = start.elapsed();
Expand All @@ -376,6 +382,6 @@ impl FromRequestParts<AppState> for BoxRepository {
histogram.record(duration_ms, &[]);
}

Ok(repo.boxed())
Ok(repo)
}
}
4 changes: 2 additions & 2 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
use mas_listener::server::Server;
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use mas_storage_pg::{PgRepositoryFactory, MIGRATOR};
use sqlx::migrate::Migrate;
use tracing::{Instrument, info, info_span, warn};

Expand Down Expand Up @@ -226,7 +226,7 @@ impl Options {

let state = {
let mut s = AppState {
pool,
repository_factory: PgRepositoryFactory::new(pool),
templates,
key_store,
cookie_manager,
Expand Down
6 changes: 5 additions & 1 deletion crates/storage-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ pub(crate) mod repository;
pub(crate) mod tracing;

pub(crate) use self::errors::DatabaseInconsistencyError;
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
pub use self::{
errors::DatabaseError,
repository::{PgRepository, PgRepositoryFactory},
tracing::ExecuteExt,
};

/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = {
Expand Down
41 changes: 40 additions & 1 deletion crates/storage-pg/src/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

use std::ops::{Deref, DerefMut};

use async_trait::async_trait;
use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
use mas_storage::{
BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
RepositoryFactory, RepositoryTransaction,
app_session::AppSessionRepository,
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
Expand Down Expand Up @@ -57,6 +59,43 @@ use crate::{
},
};

/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
/// connection pool.
#[derive(Clone)]
pub struct PgRepositoryFactory {
pool: PgPool,
}

impl PgRepositoryFactory {
/// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
#[must_use]
pub fn new(pool: PgPool) -> Self {
Self { pool }
}

/// Box the factory
#[must_use]
pub fn boxed(self) -> BoxRepositoryFactory {
Box::new(self)
}

/// Get the underlying PostgreSQL connection pool
#[must_use]
pub fn pool(&self) -> PgPool {
self.pool.clone()
}
}

#[async_trait]
impl RepositoryFactory for PgRepositoryFactory {
async fn create(&self) -> Result<BoxRepository, RepositoryError> {
Ok(PgRepository::from_pool(&self.pool)
.await
.map_err(RepositoryError::from_error)?
.boxed())
}
}

/// An implementation of the [`Repository`] trait backed by a PostgreSQL
/// transaction.
pub struct PgRepository<C = Transaction<'static, Postgres>> {
Expand Down
3 changes: 2 additions & 1 deletion crates/storage/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ pub use self::{
clock::{Clock, SystemClock},
pagination::{Page, Pagination},
repository::{
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
BoxRepository, BoxRepositoryFactory, Repository, RepositoryAccess, RepositoryError,
RepositoryFactory, RepositoryTransaction,
},
utils::{BoxClock, BoxRng, MapErr},
};
12 changes: 12 additions & 0 deletions crates/storage/src/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use async_trait::async_trait;
use futures_util::future::BoxFuture;
use thiserror::Error;

Expand All @@ -29,6 +30,17 @@ use crate::{
},
};

/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`]
// XXX(quenting): this could be generic over the repository type, but it's annoying to make it dyn-safe
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meh, don't overcomplicate it any more than it needs to be. For this, I don't think we should worry about the boxing

#[async_trait]
pub trait RepositoryFactory {
/// Create a new [`BoxRepository`]
async fn create(&self) -> Result<BoxRepository, RepositoryError>;
}

/// A type-erased [`RepositoryFactory`]
pub type BoxRepositoryFactory = Box<dyn RepositoryFactory + Send + Sync + 'static>;

/// A [`Repository`] helps interacting with the underlying storage backend.
pub trait Repository<E>:
RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
Expand Down