Skip to content

Commit e9589ae

Browse files
authored
Don't hold database connections open when talking to the homeserver (#4527)
2 parents e99d621 + 8d7be72 commit e9589ae

File tree

20 files changed

+257
-148
lines changed

20 files changed

+257
-148
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/cli/src/app_state.rs

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// SPDX-License-Identifier: AGPL-3.0-only
55
// Please see LICENSE in the repository root for full details.
66

7-
use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Instant};
7+
use std::{convert::Infallible, net::IpAddr, sync::Arc};
88

99
use axum::extract::{FromRef, FromRequestParts};
1010
use ipnetwork::IpNetwork;
@@ -19,10 +19,12 @@ use mas_keystore::{Encrypter, Keystore};
1919
use mas_matrix::HomeserverConnection;
2020
use mas_policy::{Policy, PolicyFactory};
2121
use mas_router::UrlBuilder;
22-
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
23-
use mas_storage_pg::PgRepository;
22+
use mas_storage::{
23+
BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, RepositoryFactory, SystemClock,
24+
};
25+
use mas_storage_pg::PgRepositoryFactory;
2426
use mas_templates::Templates;
25-
use opentelemetry::{KeyValue, metrics::Histogram};
27+
use opentelemetry::KeyValue;
2628
use rand::SeedableRng;
2729
use sqlx::PgPool;
2830
use tracing::Instrument;
@@ -31,7 +33,7 @@ use crate::telemetry::METER;
3133

3234
#[derive(Clone)]
3335
pub struct AppState {
34-
pub pool: PgPool,
36+
pub repository_factory: PgRepositoryFactory,
3537
pub templates: Templates,
3638
pub key_store: Keystore,
3739
pub cookie_manager: CookieManager,
@@ -47,13 +49,12 @@ pub struct AppState {
4749
pub activity_tracker: ActivityTracker,
4850
pub trusted_proxies: Vec<IpNetwork>,
4951
pub limiter: Limiter,
50-
pub conn_acquisition_histogram: Option<Histogram<u64>>,
5152
}
5253

5354
impl AppState {
5455
/// Init the metrics for the app state.
5556
pub fn init_metrics(&mut self) {
56-
let pool = self.pool.clone();
57+
let pool = self.repository_factory.pool();
5758
METER
5859
.i64_observable_up_down_counter("db.connections.usage")
5960
.with_description("The number of connections that are currently in `state` described by the state attribute.")
@@ -66,7 +67,7 @@ impl AppState {
6667
})
6768
.build();
6869

69-
let pool = self.pool.clone();
70+
let pool = self.repository_factory.pool();
7071
METER
7172
.i64_observable_up_down_counter("db.connections.max")
7273
.with_description("The maximum number of open connections allowed.")
@@ -76,26 +77,18 @@ impl AppState {
7677
instrument.observe(i64::from(max_conn), &[]);
7778
})
7879
.build();
79-
80-
// Track the connection acquisition time
81-
let histogram = METER
82-
.u64_histogram("db.client.connections.create_time")
83-
.with_description("The time it took to create a new connection.")
84-
.with_unit("ms")
85-
.build();
86-
self.conn_acquisition_histogram = Some(histogram);
8780
}
8881

8982
/// Init the metadata cache in the background
9083
pub fn init_metadata_cache(&self) {
91-
let pool = self.pool.clone();
84+
let factory = self.repository_factory.clone();
9285
let metadata_cache = self.metadata_cache.clone();
9386
let http_client = self.http_client.clone();
9487

9588
tokio::spawn(
9689
LogContext::new("metadata-cache-warmup")
9790
.run(async move || {
98-
let conn = match pool.acquire().await {
91+
let mut repo = match factory.create().await {
9992
Ok(conn) => conn,
10093
Err(e) => {
10194
tracing::error!(
@@ -106,8 +99,6 @@ impl AppState {
10699
}
107100
};
108101

109-
let mut repo = PgRepository::from_conn(conn);
110-
111102
if let Err(e) = metadata_cache
112103
.warm_up_and_run(
113104
&http_client,
@@ -127,9 +118,17 @@ impl AppState {
127118
}
128119
}
129120

121+
// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
122+
// should be part of the repository
130123
impl FromRef<AppState> for PgPool {
131124
fn from_ref(input: &AppState) -> Self {
132-
input.pool.clone()
125+
input.repository_factory.pool()
126+
}
127+
}
128+
129+
impl FromRef<AppState> for BoxRepositoryFactory {
130+
fn from_ref(input: &AppState) -> Self {
131+
input.repository_factory.clone().boxed()
133132
}
134133
}
135134

@@ -359,23 +358,13 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
359358
}
360359

361360
impl FromRequestParts<AppState> for BoxRepository {
362-
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
361+
type Rejection = ErrorWrapper<mas_storage::RepositoryError>;
363362

364363
async fn from_request_parts(
365364
_parts: &mut axum::http::request::Parts,
366365
state: &AppState,
367366
) -> Result<Self, Self::Rejection> {
368-
let start = Instant::now();
369-
let repo = PgRepository::from_pool(&state.pool).await?;
370-
371-
// Measure the time it took to create the connection
372-
let duration = start.elapsed();
373-
let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
374-
375-
if let Some(histogram) = &state.conn_acquisition_histogram {
376-
histogram.record(duration_ms, &[]);
377-
}
378-
379-
Ok(repo.boxed())
367+
let repo = state.repository_factory.create().await?;
368+
Ok(repo)
380369
}
381370
}

crates/cli/src/commands/debug.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use figment::Figment;
1111
use mas_config::{
1212
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
1313
};
14+
use mas_storage_pg::PgRepositoryFactory;
1415
use tracing::{info, info_span};
1516

1617
use crate::util::{
@@ -48,7 +49,8 @@ impl Options {
4849
if with_dynamic_data {
4950
let database_config = DatabaseConfig::extract(figment)?;
5051
let pool = database_pool_from_config(&database_config).await?;
51-
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
52+
let repository_factory = PgRepositoryFactory::new(pool.clone());
53+
load_policy_factory_dynamic_data(&policy_factory, &repository_factory).await?;
5254
}
5355

5456
let _instance = policy_factory.instantiate().await?;

crates/cli/src/commands/server.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
1818
use mas_listener::server::Server;
1919
use mas_router::UrlBuilder;
2020
use mas_storage::SystemClock;
21-
use mas_storage_pg::MIGRATOR;
21+
use mas_storage_pg::{MIGRATOR, PgRepositoryFactory};
2222
use sqlx::migrate::Migrate;
2323
use tracing::{Instrument, info, info_span, warn};
2424

@@ -134,7 +134,7 @@ impl Options {
134134

135135
load_policy_factory_dynamic_data_continuously(
136136
&policy_factory,
137-
&pool,
137+
PgRepositoryFactory::new(pool.clone()).boxed(),
138138
shutdown.soft_shutdown_token(),
139139
shutdown.task_tracker(),
140140
)
@@ -172,7 +172,7 @@ impl Options {
172172

173173
info!("Starting task worker");
174174
mas_tasks::init(
175-
&pool,
175+
PgRepositoryFactory::new(pool.clone()),
176176
&mailer,
177177
homeserver_connection.clone(),
178178
url_builder.clone(),
@@ -193,7 +193,7 @@ impl Options {
193193
// Initialize the activity tracker
194194
// Activity is flushed every minute
195195
let activity_tracker = ActivityTracker::new(
196-
pool.clone(),
196+
PgRepositoryFactory::new(pool.clone()).boxed(),
197197
Duration::from_secs(60),
198198
shutdown.task_tracker(),
199199
shutdown.soft_shutdown_token(),
@@ -215,7 +215,7 @@ impl Options {
215215
limiter.start();
216216

217217
let graphql_schema = mas_handlers::graphql_schema(
218-
&pool,
218+
PgRepositoryFactory::new(pool.clone()).boxed(),
219219
&policy_factory,
220220
homeserver_connection.clone(),
221221
site_config.clone(),
@@ -226,7 +226,7 @@ impl Options {
226226

227227
let state = {
228228
let mut s = AppState {
229-
pool,
229+
repository_factory: PgRepositoryFactory::new(pool),
230230
templates,
231231
key_store,
232232
cookie_manager,
@@ -242,7 +242,6 @@ impl Options {
242242
activity_tracker,
243243
trusted_proxies,
244244
limiter,
245-
conn_acquisition_histogram: None,
246245
};
247246
s.init_metrics();
248247
s.init_metadata_cache();

crates/cli/src/commands/worker.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use clap::Parser;
1010
use figment::Figment;
1111
use mas_config::{AppConfig, ConfigurationSection};
1212
use mas_router::UrlBuilder;
13+
use mas_storage_pg::PgRepositoryFactory;
1314
use tracing::{info, info_span};
1415

1516
use crate::{
@@ -63,7 +64,7 @@ impl Options {
6364

6465
info!("Starting task scheduler");
6566
mas_tasks::init(
66-
&pool,
67+
PgRepositoryFactory::new(pool.clone()),
6768
&mailer,
6869
conn,
6970
url_builder,

crates/cli/src/util.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
2020
use mas_matrix_synapse::SynapseConnection;
2121
use mas_policy::PolicyFactory;
2222
use mas_router::UrlBuilder;
23-
use mas_storage::RepositoryAccess;
24-
use mas_storage_pg::PgRepository;
23+
use mas_storage::{BoxRepositoryFactory, RepositoryAccess, RepositoryFactory};
2524
use mas_templates::{SiteConfigExt, Templates};
2625
use sqlx::{
2726
ConnectOptions, Executor, PgConnection, PgPool,
@@ -400,14 +399,13 @@ pub async fn database_connection_from_config_with_options(
400399
// XXX: this could be put somewhere else?
401400
pub async fn load_policy_factory_dynamic_data_continuously(
402401
policy_factory: &Arc<PolicyFactory>,
403-
pool: &PgPool,
402+
repository_factory: BoxRepositoryFactory,
404403
cancellation_token: CancellationToken,
405404
task_tracker: &TaskTracker,
406405
) -> Result<(), anyhow::Error> {
407406
let policy_factory = policy_factory.clone();
408-
let pool = pool.clone();
409407

410-
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
408+
load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await?;
411409

412410
task_tracker.spawn(async move {
413411
let mut interval = tokio::time::interval(Duration::from_secs(60));
@@ -420,7 +418,9 @@ pub async fn load_policy_factory_dynamic_data_continuously(
420418
_ = interval.tick() => {}
421419
}
422420

423-
if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await {
421+
if let Err(err) =
422+
load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await
423+
{
424424
tracing::error!(
425425
error = ?err,
426426
"Failed to load policy factory dynamic data"
@@ -438,9 +438,10 @@ pub async fn load_policy_factory_dynamic_data_continuously(
438438
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all)]
439439
pub async fn load_policy_factory_dynamic_data(
440440
policy_factory: &PolicyFactory,
441-
pool: &PgPool,
441+
repository_factory: &(dyn RepositoryFactory + Send + Sync),
442442
) -> Result<(), anyhow::Error> {
443-
let mut repo = PgRepository::from_pool(pool)
443+
let mut repo = repository_factory
444+
.create()
444445
.await
445446
.context("Failed to acquire database connection")?;
446447

crates/handlers/src/activity_tracker/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ use std::net::IpAddr;
1111

1212
use chrono::{DateTime, Utc};
1313
use mas_data_model::{BrowserSession, CompatSession, Session};
14-
use mas_storage::Clock;
15-
use sqlx::PgPool;
14+
use mas_storage::{BoxRepositoryFactory, Clock};
1615
use tokio_util::{sync::CancellationToken, task::TaskTracker};
1716
use ulid::Ulid;
1817

@@ -61,12 +60,12 @@ impl ActivityTracker {
6160
/// time, when the cancellation token is cancelled.
6261
#[must_use]
6362
pub fn new(
64-
pool: PgPool,
63+
repository_factory: BoxRepositoryFactory,
6564
flush_interval: std::time::Duration,
6665
task_tracker: &TaskTracker,
6766
cancellation_token: CancellationToken,
6867
) -> Self {
69-
let worker = Worker::new(pool);
68+
let worker = Worker::new(repository_factory);
7069
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
7170
let tracker = ActivityTracker { channel: sender };
7271

crates/handlers/src/activity_tracker/worker.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
use std::{collections::HashMap, net::IpAddr};
88

99
use chrono::{DateTime, Utc};
10-
use mas_storage::{RepositoryAccess, RepositoryError, user::BrowserSessionRepository};
10+
use mas_storage::{
11+
BoxRepositoryFactory, RepositoryAccess, RepositoryError, user::BrowserSessionRepository,
12+
};
1113
use opentelemetry::{
1214
Key, KeyValue,
1315
metrics::{Counter, Gauge, Histogram},
1416
};
15-
use sqlx::PgPool;
1617
use tokio_util::sync::CancellationToken;
1718
use ulid::Ulid;
1819

@@ -43,15 +44,15 @@ struct ActivityRecord {
4344

4445
/// Handles writing activity records to the database.
4546
pub struct Worker {
46-
pool: PgPool,
47+
repository_factory: BoxRepositoryFactory,
4748
pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
4849
pending_records_gauge: Gauge<u64>,
4950
message_counter: Counter<u64>,
5051
flush_time_histogram: Histogram<u64>,
5152
}
5253

5354
impl Worker {
54-
pub(crate) fn new(pool: PgPool) -> Self {
55+
pub(crate) fn new(repository_factory: BoxRepositoryFactory) -> Self {
5556
let message_counter = METER
5657
.u64_counter("mas.activity_tracker.messages")
5758
.with_description("The number of messages received by the activity tracker")
@@ -89,7 +90,7 @@ impl Worker {
8990
pending_records_gauge.record(0, &[]);
9091

9192
Self {
92-
pool,
93+
repository_factory,
9394
pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
9495
pending_records_gauge,
9596
message_counter,
@@ -218,11 +219,7 @@ impl Worker {
218219
#[tracing::instrument(name = "activity_tracker.flush", skip(self))]
219220
async fn try_flush(&mut self) -> Result<(), RepositoryError> {
220221
let pending_records = &self.pending_records;
221-
222-
let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
223-
.await
224-
.map_err(RepositoryError::from_error)?
225-
.boxed();
222+
let mut repo = self.repository_factory.create().await?;
226223

227224
let mut browser_sessions = Vec::new();
228225
let mut oauth2_sessions = Vec::new();

0 commit comments

Comments
 (0)