44// SPDX-License-Identifier: AGPL-3.0-only
55// Please see LICENSE in the repository root for full details.
66
7- use std:: { convert:: Infallible , net:: IpAddr , sync:: Arc , time :: Instant } ;
7+ use std:: { convert:: Infallible , net:: IpAddr , sync:: Arc } ;
88
99use axum:: extract:: { FromRef , FromRequestParts } ;
1010use ipnetwork:: IpNetwork ;
@@ -19,10 +19,12 @@ use mas_keystore::{Encrypter, Keystore};
1919use mas_matrix:: HomeserverConnection ;
2020use mas_policy:: { Policy , PolicyFactory } ;
2121use mas_router:: UrlBuilder ;
22- use mas_storage:: { BoxClock , BoxRepository , BoxRng , SystemClock } ;
23- use mas_storage_pg:: PgRepository ;
22+ use mas_storage:: {
23+ BoxClock , BoxRepository , BoxRepositoryFactory , BoxRng , RepositoryFactory , SystemClock ,
24+ } ;
25+ use mas_storage_pg:: PgRepositoryFactory ;
2426use mas_templates:: Templates ;
25- use opentelemetry:: { KeyValue , metrics :: Histogram } ;
27+ use opentelemetry:: KeyValue ;
2628use rand:: SeedableRng ;
2729use sqlx:: PgPool ;
2830use tracing:: Instrument ;
@@ -31,7 +33,7 @@ use crate::telemetry::METER;
3133
3234#[ derive( Clone ) ]
3335pub struct AppState {
34- pub pool : PgPool ,
36+ pub repository_factory : PgRepositoryFactory ,
3537 pub templates : Templates ,
3638 pub key_store : Keystore ,
3739 pub cookie_manager : CookieManager ,
@@ -47,13 +49,12 @@ pub struct AppState {
4749 pub activity_tracker : ActivityTracker ,
4850 pub trusted_proxies : Vec < IpNetwork > ,
4951 pub limiter : Limiter ,
50- pub conn_acquisition_histogram : Option < Histogram < u64 > > ,
5152}
5253
5354impl AppState {
5455 /// Init the metrics for the app state.
5556 pub fn init_metrics ( & mut self ) {
56- let pool = self . pool . clone ( ) ;
57+ let pool = self . repository_factory . pool ( ) ;
5758 METER
5859 . i64_observable_up_down_counter ( "db.connections.usage" )
5960 . with_description ( "The number of connections that are currently in `state` described by the state attribute." )
@@ -66,7 +67,7 @@ impl AppState {
6667 } )
6768 . build ( ) ;
6869
69- let pool = self . pool . clone ( ) ;
70+ let pool = self . repository_factory . pool ( ) ;
7071 METER
7172 . i64_observable_up_down_counter ( "db.connections.max" )
7273 . with_description ( "The maximum number of open connections allowed." )
@@ -76,26 +77,18 @@ impl AppState {
7677 instrument. observe ( i64:: from ( max_conn) , & [ ] ) ;
7778 } )
7879 . build ( ) ;
79-
80- // Track the connection acquisition time
81- let histogram = METER
82- . u64_histogram ( "db.client.connections.create_time" )
83- . with_description ( "The time it took to create a new connection." )
84- . with_unit ( "ms" )
85- . build ( ) ;
86- self . conn_acquisition_histogram = Some ( histogram) ;
8780 }
8881
8982 /// Init the metadata cache in the background
9083 pub fn init_metadata_cache ( & self ) {
91- let pool = self . pool . clone ( ) ;
84+ let factory = self . repository_factory . clone ( ) ;
9285 let metadata_cache = self . metadata_cache . clone ( ) ;
9386 let http_client = self . http_client . clone ( ) ;
9487
9588 tokio:: spawn (
9689 LogContext :: new ( "metadata-cache-warmup" )
9790 . run ( async move || {
98- let conn = match pool . acquire ( ) . await {
91+ let mut repo = match factory . create ( ) . await {
9992 Ok ( conn) => conn,
10093 Err ( e) => {
10194 tracing:: error!(
@@ -106,8 +99,6 @@ impl AppState {
10699 }
107100 } ;
108101
109- let mut repo = PgRepository :: from_conn ( conn) ;
110-
111102 if let Err ( e) = metadata_cache
112103 . warm_up_and_run (
113104 & http_client,
@@ -127,9 +118,17 @@ impl AppState {
127118 }
128119}
129120
121+ // XXX(quenting): we only use this for the healthcheck endpoint, checking the db
122+ // should be part of the repository
130123impl FromRef < AppState > for PgPool {
131124 fn from_ref ( input : & AppState ) -> Self {
132- input. pool . clone ( )
125+ input. repository_factory . pool ( )
126+ }
127+ }
128+
129+ impl FromRef < AppState > for BoxRepositoryFactory {
130+ fn from_ref ( input : & AppState ) -> Self {
131+ input. repository_factory . clone ( ) . boxed ( )
133132 }
134133}
135134
@@ -359,23 +358,13 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
359358}
360359
361360impl FromRequestParts < AppState > for BoxRepository {
362- type Rejection = ErrorWrapper < mas_storage_pg :: DatabaseError > ;
361+ type Rejection = ErrorWrapper < mas_storage :: RepositoryError > ;
363362
364363 async fn from_request_parts (
365364 _parts : & mut axum:: http:: request:: Parts ,
366365 state : & AppState ,
367366 ) -> Result < Self , Self :: Rejection > {
368- let start = Instant :: now ( ) ;
369- let repo = PgRepository :: from_pool ( & state. pool ) . await ?;
370-
371- // Measure the time it took to create the connection
372- let duration = start. elapsed ( ) ;
373- let duration_ms = duration. as_millis ( ) . try_into ( ) . unwrap_or ( u64:: MAX ) ;
374-
375- if let Some ( histogram) = & state. conn_acquisition_histogram {
376- histogram. record ( duration_ms, & [ ] ) ;
377- }
378-
379- Ok ( repo. boxed ( ) )
367+ let repo = state. repository_factory . create ( ) . await ?;
368+ Ok ( repo)
380369 }
381370}
0 commit comments