@@ -34,8 +34,11 @@ use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
3434use mas_matrix:: { HomeserverConnection , MockHomeserverConnection } ;
3535use mas_policy:: { InstantiateError , Policy , PolicyFactory } ;
3636use mas_router:: { SimpleRoute , UrlBuilder } ;
37- use mas_storage:: { BoxClock , BoxRepository , BoxRng , clock:: MockClock } ;
38- use mas_storage_pg:: { DatabaseError , PgRepository } ;
37+ use mas_storage:: {
38+ BoxClock , BoxRepository , BoxRepositoryFactory , BoxRng , RepositoryError , RepositoryFactory ,
39+ clock:: MockClock ,
40+ } ;
41+ use mas_storage_pg:: PgRepositoryFactory ;
3942use mas_templates:: { SiteConfigExt , Templates } ;
4043use oauth2_types:: { registration:: ClientRegistrationResponse , requests:: AccessTokenResponse } ;
4144use rand:: SeedableRng ;
@@ -92,7 +95,7 @@ pub(crate) async fn policy_factory(
9295
9396#[ derive( Clone ) ]
9497pub ( crate ) struct TestState {
95- pub pool : PgPool ,
98+ pub repository_factory : PgRepositoryFactory ,
9699 pub templates : Templates ,
97100 pub key_store : Keystore ,
98101 pub cookie_manager : CookieManager ,
@@ -209,7 +212,7 @@ impl TestState {
209212 let limiter = Limiter :: new ( & RateLimitingConfig :: default ( ) ) . unwrap ( ) ;
210213
211214 let graphql_state = TestGraphQLState {
212- pool : pool. clone ( ) ,
215+ repository_factory : PgRepositoryFactory :: new ( pool. clone ( ) ) . boxed ( ) ,
213216 policy_factory : Arc :: clone ( & policy_factory) ,
214217 homeserver_connection : Arc :: clone ( & homeserver_connection) ,
215218 site_config : site_config. clone ( ) ,
@@ -224,14 +227,14 @@ impl TestState {
224227 let graphql_schema = graphql:: schema_builder ( ) . data ( state) . finish ( ) ;
225228
226229 let activity_tracker = ActivityTracker :: new (
227- pool. clone ( ) ,
230+ PgRepositoryFactory :: new ( pool. clone ( ) ) . boxed ( ) ,
228231 std:: time:: Duration :: from_secs ( 60 ) ,
229232 & task_tracker,
230233 shutdown_token. child_token ( ) ,
231234 ) ;
232235
233236 Ok ( Self {
234- pool,
237+ repository_factory : PgRepositoryFactory :: new ( pool) ,
235238 templates,
236239 key_store,
237240 cookie_manager,
@@ -256,7 +259,7 @@ impl TestState {
256259 /// Reset the test utils to a fresh state, with the same configuration.
257260 pub async fn reset ( self ) -> Self {
258261 let site_config = self . site_config . clone ( ) ;
259- let pool = self . pool . clone ( ) ;
262+ let pool = self . repository_factory . pool ( ) ;
260263 let task_tracker = self . task_tracker . clone ( ) ;
261264
262265 // This should trigger the cancellation drop guard
@@ -351,9 +354,8 @@ impl TestState {
351354 access_token
352355 }
353356
354- pub async fn repository ( & self ) -> Result < BoxRepository , DatabaseError > {
355- let repo = PgRepository :: from_pool ( & self . pool ) . await ?;
356- Ok ( repo. boxed ( ) )
357+ pub async fn repository ( & self ) -> Result < BoxRepository , RepositoryError > {
358+ self . repository_factory . create ( ) . await
357359 }
358360
359361 /// Returns a new random number generator.
@@ -393,7 +395,7 @@ impl TestState {
393395}
394396
395397struct TestGraphQLState {
396- pool : PgPool ,
398+ repository_factory : BoxRepositoryFactory ,
397399 homeserver_connection : Arc < MockHomeserverConnection > ,
398400 site_config : SiteConfig ,
399401 policy_factory : Arc < PolicyFactory > ,
@@ -407,11 +409,7 @@ struct TestGraphQLState {
407409#[ async_trait:: async_trait]
408410impl graphql:: State for TestGraphQLState {
409411 async fn repository ( & self ) -> Result < BoxRepository , mas_storage:: RepositoryError > {
410- let repo = PgRepository :: from_pool ( & self . pool )
411- . await
412- . map_err ( mas_storage:: RepositoryError :: from_error) ?;
413-
414- Ok ( repo. boxed ( ) )
412+ self . repository_factory . create ( ) . await
415413 }
416414
417415 async fn policy ( & self ) -> Result < Policy , InstantiateError > {
@@ -451,7 +449,7 @@ impl graphql::State for TestGraphQLState {
451449
452450impl FromRef < TestState > for PgPool {
453451 fn from_ref ( input : & TestState ) -> Self {
454- input. pool . clone ( )
452+ input. repository_factory . pool ( )
455453 }
456454}
457455
@@ -598,14 +596,14 @@ impl FromRequestParts<TestState> for BoxRng {
598596}
599597
600598impl FromRequestParts < TestState > for BoxRepository {
601- type Rejection = ErrorWrapper < mas_storage_pg :: DatabaseError > ;
599+ type Rejection = ErrorWrapper < RepositoryError > ;
602600
603601 async fn from_request_parts (
604602 _parts : & mut axum:: http:: request:: Parts ,
605603 state : & TestState ,
606604 ) -> Result < Self , Self :: Rejection > {
607- let repo = PgRepository :: from_pool ( & state. pool ) . await ?;
608- Ok ( repo. boxed ( ) )
605+ let repo = state. repository_factory . create ( ) . await ?;
606+ Ok ( repo)
609607 }
610608}
611609
0 commit comments