Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 23 additions & 34 deletions crates/cli/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Instant};
use std::{convert::Infallible, net::IpAddr, sync::Arc};

use axum::extract::{FromRef, FromRequestParts};
use ipnetwork::IpNetwork;
Expand All @@ -19,10 +19,12 @@ use mas_keystore::{Encrypter, Keystore};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
use mas_storage_pg::PgRepository;
use mas_storage::{
BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, RepositoryFactory, SystemClock,
};
use mas_storage_pg::PgRepositoryFactory;
use mas_templates::Templates;
use opentelemetry::{KeyValue, metrics::Histogram};
use opentelemetry::KeyValue;
use rand::SeedableRng;
use sqlx::PgPool;
use tracing::Instrument;
Expand All @@ -31,7 +33,7 @@ use crate::telemetry::METER;

#[derive(Clone)]
pub struct AppState {
pub pool: PgPool,
pub repository_factory: PgRepositoryFactory,
pub templates: Templates,
pub key_store: Keystore,
pub cookie_manager: CookieManager,
Expand All @@ -47,13 +49,12 @@ pub struct AppState {
pub activity_tracker: ActivityTracker,
pub trusted_proxies: Vec<IpNetwork>,
pub limiter: Limiter,
pub conn_acquisition_histogram: Option<Histogram<u64>>,
}

impl AppState {
/// Init the metrics for the app state.
pub fn init_metrics(&mut self) {
let pool = self.pool.clone();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.usage")
.with_description("The number of connections that are currently in `state` described by the state attribute.")
Expand All @@ -66,7 +67,7 @@ impl AppState {
})
.build();

let pool = self.pool.clone();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.max")
.with_description("The maximum number of open connections allowed.")
Expand All @@ -76,26 +77,18 @@ impl AppState {
instrument.observe(i64::from(max_conn), &[]);
})
.build();

// Track the connection acquisition time
let histogram = METER
.u64_histogram("db.client.connections.create_time")
.with_description("The time it took to create a new connection.")
.with_unit("ms")
.build();
self.conn_acquisition_histogram = Some(histogram);
}

/// Init the metadata cache in the background
pub fn init_metadata_cache(&self) {
let pool = self.pool.clone();
let factory = self.repository_factory.clone();
let metadata_cache = self.metadata_cache.clone();
let http_client = self.http_client.clone();

tokio::spawn(
LogContext::new("metadata-cache-warmup")
.run(async move || {
let conn = match pool.acquire().await {
let mut repo = match factory.create().await {
Ok(conn) => conn,
Err(e) => {
tracing::error!(
Expand All @@ -106,8 +99,6 @@ impl AppState {
}
};

let mut repo = PgRepository::from_conn(conn);

if let Err(e) = metadata_cache
.warm_up_and_run(
&http_client,
Expand All @@ -127,9 +118,17 @@ impl AppState {
}
}

// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
// should be part of the repository
impl FromRef<AppState> for PgPool {
fn from_ref(input: &AppState) -> Self {
input.pool.clone()
input.repository_factory.pool()
}
}

impl FromRef<AppState> for BoxRepositoryFactory {
fn from_ref(input: &AppState) -> Self {
input.repository_factory.clone().boxed()
}
}

Expand Down Expand Up @@ -359,23 +358,13 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
}

impl FromRequestParts<AppState> for BoxRepository {
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
type Rejection = ErrorWrapper<mas_storage::RepositoryError>;

async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let repo = PgRepository::from_pool(&state.pool).await?;

// Measure the time it took to create the connection
let duration = start.elapsed();
let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);

if let Some(histogram) = &state.conn_acquisition_histogram {
histogram.record(duration_ms, &[]);
}

Ok(repo.boxed())
let repo = state.repository_factory.create().await?;
Ok(repo)
}
}
4 changes: 3 additions & 1 deletion crates/cli/src/commands/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
};
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};

use crate::util::{
Expand Down Expand Up @@ -48,7 +49,8 @@ impl Options {
if with_dynamic_data {
let database_config = DatabaseConfig::extract(figment)?;
let pool = database_pool_from_config(&database_config).await?;
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
let repository_factory = PgRepositoryFactory::new(pool.clone());
load_policy_factory_dynamic_data(&policy_factory, &repository_factory).await?;
}

let _instance = policy_factory.instantiate().await?;
Expand Down
13 changes: 6 additions & 7 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
use mas_listener::server::Server;
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use mas_storage_pg::{MIGRATOR, PgRepositoryFactory};
use sqlx::migrate::Migrate;
use tracing::{Instrument, info, info_span, warn};

Expand Down Expand Up @@ -134,7 +134,7 @@ impl Options {

load_policy_factory_dynamic_data_continuously(
&policy_factory,
&pool,
PgRepositoryFactory::new(pool.clone()).boxed(),
shutdown.soft_shutdown_token(),
shutdown.task_tracker(),
)
Expand Down Expand Up @@ -172,7 +172,7 @@ impl Options {

info!("Starting task worker");
mas_tasks::init(
&pool,
PgRepositoryFactory::new(pool.clone()),
&mailer,
homeserver_connection.clone(),
url_builder.clone(),
Expand All @@ -193,7 +193,7 @@ impl Options {
// Initialize the activity tracker
// Activity is flushed every minute
let activity_tracker = ActivityTracker::new(
pool.clone(),
PgRepositoryFactory::new(pool.clone()).boxed(),
Duration::from_secs(60),
shutdown.task_tracker(),
shutdown.soft_shutdown_token(),
Expand All @@ -215,7 +215,7 @@ impl Options {
limiter.start();

let graphql_schema = mas_handlers::graphql_schema(
&pool,
PgRepositoryFactory::new(pool.clone()).boxed(),
&policy_factory,
homeserver_connection.clone(),
site_config.clone(),
Expand All @@ -226,7 +226,7 @@ impl Options {

let state = {
let mut s = AppState {
pool,
repository_factory: PgRepositoryFactory::new(pool),
templates,
key_store,
cookie_manager,
Expand All @@ -242,7 +242,6 @@ impl Options {
activity_tracker,
trusted_proxies,
limiter,
conn_acquisition_histogram: None,
};
s.init_metrics();
s.init_metadata_cache();
Expand Down
3 changes: 2 additions & 1 deletion crates/cli/src/commands/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use clap::Parser;
use figment::Figment;
use mas_config::{AppConfig, ConfigurationSection};
use mas_router::UrlBuilder;
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};

use crate::{
Expand Down Expand Up @@ -63,7 +64,7 @@ impl Options {

info!("Starting task scheduler");
mas_tasks::init(
&pool,
PgRepositoryFactory::new(pool.clone()),
&mailer,
conn,
url_builder,
Expand Down
17 changes: 9 additions & 8 deletions crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
use mas_matrix_synapse::SynapseConnection;
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::RepositoryAccess;
use mas_storage_pg::PgRepository;
use mas_storage::{BoxRepositoryFactory, RepositoryAccess, RepositoryFactory};
use mas_templates::{SiteConfigExt, Templates};
use sqlx::{
ConnectOptions, Executor, PgConnection, PgPool,
Expand Down Expand Up @@ -400,14 +399,13 @@ pub async fn database_connection_from_config_with_options(
// XXX: this could be put somewhere else?
pub async fn load_policy_factory_dynamic_data_continuously(
policy_factory: &Arc<PolicyFactory>,
pool: &PgPool,
repository_factory: BoxRepositoryFactory,
cancellation_token: CancellationToken,
task_tracker: &TaskTracker,
) -> Result<(), anyhow::Error> {
let policy_factory = policy_factory.clone();
let pool = pool.clone();

load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await?;

task_tracker.spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
Expand All @@ -420,7 +418,9 @@ pub async fn load_policy_factory_dynamic_data_continuously(
_ = interval.tick() => {}
}

if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await {
if let Err(err) =
load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await
{
tracing::error!(
error = ?err,
"Failed to load policy factory dynamic data"
Expand All @@ -438,9 +438,10 @@ pub async fn load_policy_factory_dynamic_data_continuously(
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all)]
pub async fn load_policy_factory_dynamic_data(
policy_factory: &PolicyFactory,
pool: &PgPool,
repository_factory: &(dyn RepositoryFactory + Send + Sync),
) -> Result<(), anyhow::Error> {
let mut repo = PgRepository::from_pool(pool)
let mut repo = repository_factory
.create()
.await
.context("Failed to acquire database connection")?;

Expand Down
7 changes: 3 additions & 4 deletions crates/handlers/src/activity_tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ use std::net::IpAddr;

use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, CompatSession, Session};
use mas_storage::Clock;
use sqlx::PgPool;
use mas_storage::{BoxRepositoryFactory, Clock};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use ulid::Ulid;

Expand Down Expand Up @@ -61,12 +60,12 @@ impl ActivityTracker {
/// time, when the cancellation token is cancelled.
#[must_use]
pub fn new(
pool: PgPool,
repository_factory: BoxRepositoryFactory,
flush_interval: std::time::Duration,
task_tracker: &TaskTracker,
cancellation_token: CancellationToken,
) -> Self {
let worker = Worker::new(pool);
let worker = Worker::new(repository_factory);
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
let tracker = ActivityTracker { channel: sender };

Expand Down
17 changes: 7 additions & 10 deletions crates/handlers/src/activity_tracker/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
use std::{collections::HashMap, net::IpAddr};

use chrono::{DateTime, Utc};
use mas_storage::{RepositoryAccess, RepositoryError, user::BrowserSessionRepository};
use mas_storage::{
BoxRepositoryFactory, RepositoryAccess, RepositoryError, user::BrowserSessionRepository,
};
use opentelemetry::{
Key, KeyValue,
metrics::{Counter, Gauge, Histogram},
};
use sqlx::PgPool;
use tokio_util::sync::CancellationToken;
use ulid::Ulid;

Expand Down Expand Up @@ -43,15 +44,15 @@ struct ActivityRecord {

/// Handles writing activity records to the database.
pub struct Worker {
pool: PgPool,
repository_factory: BoxRepositoryFactory,
pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
pending_records_gauge: Gauge<u64>,
message_counter: Counter<u64>,
flush_time_histogram: Histogram<u64>,
}

impl Worker {
pub(crate) fn new(pool: PgPool) -> Self {
pub(crate) fn new(repository_factory: BoxRepositoryFactory) -> Self {
let message_counter = METER
.u64_counter("mas.activity_tracker.messages")
.with_description("The number of messages received by the activity tracker")
Expand Down Expand Up @@ -89,7 +90,7 @@ impl Worker {
pending_records_gauge.record(0, &[]);

Self {
pool,
repository_factory,
pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
pending_records_gauge,
message_counter,
Expand Down Expand Up @@ -218,11 +219,7 @@ impl Worker {
#[tracing::instrument(name = "activity_tracker.flush", skip(self))]
async fn try_flush(&mut self) -> Result<(), RepositoryError> {
let pending_records = &self.pending_records;

let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
.await
.map_err(RepositoryError::from_error)?
.boxed();
let mut repo = self.repository_factory.create().await?;

let mut browser_sessions = Vec::new();
let mut oauth2_sessions = Vec::new();
Expand Down
Loading
Loading