diff --git a/crates/cli/src/commands/syn2mas.rs b/crates/cli/src/commands/syn2mas.rs index 63fe440eb..e6ab68759 100644 --- a/crates/cli/src/commands/syn2mas.rs +++ b/crates/cli/src/commands/syn2mas.rs @@ -142,7 +142,7 @@ impl Options { .await?; } - let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(&mut mas_connection) + let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection) .await .context("failed to issue query to lock database")? else { diff --git a/crates/syn2mas/src/mas_writer/checks.rs b/crates/syn2mas/src/mas_writer/checks.rs index a8ea1a18a..f3e98e57c 100644 --- a/crates/syn2mas/src/mas_writer/checks.rs +++ b/crates/syn2mas/src/mas_writer/checks.rs @@ -47,9 +47,7 @@ pub enum Error { /// - If we can't check whether syn2mas is already in progress on this database /// or not. #[tracing::instrument(skip_all)] -pub async fn mas_pre_migration_checks<'a>( - mas_connection: &mut LockedMasDatabase<'a>, -) -> Result<(), Error> { +pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> { if is_syn2mas_in_progress(mas_connection.as_mut()) .await .map_err(Error::UnableToCheckInProgress)? diff --git a/crates/syn2mas/src/mas_writer/locking.rs b/crates/syn2mas/src/mas_writer/locking.rs index f034025bf..2425ca269 100644 --- a/crates/syn2mas/src/mas_writer/locking.rs +++ b/crates/syn2mas/src/mas_writer/locking.rs @@ -15,11 +15,11 @@ static SYN2MAS_ADVISORY_LOCK: LazyLock = /// A wrapper around a Postgres connection which holds a session-wide advisory /// lock preventing concurrent access by other syn2mas instances. -pub struct LockedMasDatabase<'conn> { - inner: PgAdvisoryLockGuard<'static, &'conn mut PgConnection>, +pub struct LockedMasDatabase { + inner: PgAdvisoryLockGuard<'static, PgConnection>, } -impl<'conn> LockedMasDatabase<'conn> { +impl LockedMasDatabase { /// Attempts to lock the MAS database against concurrent access by other /// syn2mas instances. /// @@ -31,8 +31,8 @@ impl<'conn> LockedMasDatabase<'conn> { /// /// Errors are returned for underlying database errors. pub async fn try_new( - mas_connection: &'conn mut PgConnection, - ) -> Result, sqlx::Error> { + mas_connection: PgConnection, + ) -> Result, sqlx::Error> { SYN2MAS_ADVISORY_LOCK .try_acquire(mas_connection) .await @@ -48,12 +48,12 @@ impl<'conn> LockedMasDatabase<'conn> { /// # Errors /// /// Errors are returned for underlying database errors. - pub async fn unlock(self) -> Result<&'conn mut PgConnection, sqlx::Error> { + pub async fn unlock(self) -> Result { self.inner.release_now().await } } -impl AsMut for LockedMasDatabase<'_> { +impl AsMut for LockedMasDatabase { fn as_mut(&mut self) -> &mut PgConnection { self.inner.as_mut() } diff --git a/crates/syn2mas/src/mas_writer/mod.rs b/crates/syn2mas/src/mas_writer/mod.rs index f1362702c..e5260ad4f 100644 --- a/crates/syn2mas/src/mas_writer/mod.rs +++ b/crates/syn2mas/src/mas_writer/mod.rs @@ -185,8 +185,8 @@ impl WriterConnectionPool { } } -pub struct MasWriter<'c> { - conn: LockedMasDatabase<'c>, +pub struct MasWriter { + conn: LockedMasDatabase, writer_pool: WriterConnectionPool, indices_to_restore: Vec, @@ -324,7 +324,7 @@ pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result MasWriter<'conn> { +impl MasWriter { /// Creates a new MAS writer. /// /// # Errors @@ -335,7 +335,7 @@ impl<'conn> MasWriter<'conn> { #[allow(clippy::missing_panics_doc)] // not real #[tracing::instrument(skip_all)] pub async fn new( - mut conn: LockedMasDatabase<'conn>, + mut conn: LockedMasDatabase, mut writer_connections: Vec, ) -> Result { // Given that we don't have any concurrent transactions here, @@ -446,6 +446,7 @@ impl<'conn> MasWriter<'conn> { Ok(Self { conn, + writer_pool: WriterConnectionPool::new(writer_connections), indices_to_restore, constraints_to_restore, @@ -488,7 +489,7 @@ impl<'conn> MasWriter<'conn> { } async fn restore_indices( - conn: &mut LockedMasDatabase<'_>, + conn: &mut LockedMasDatabase, indices_to_restore: &[IndexDescription], constraints_to_restore: &[ConstraintDescription], ) -> Result<(), Error> { @@ -507,6 +508,7 @@ impl<'conn> MasWriter<'conn> { } /// Finish writing to the MAS database, flushing and committing all changes. + /// It returns the unlocked underlying connection. /// /// # Errors /// @@ -514,7 +516,7 @@ impl<'conn> MasWriter<'conn> { /// /// - If the database connection experiences an error. #[tracing::instrument(skip_all)] - pub async fn finish(mut self) -> Result<(), Error> { + pub async fn finish(mut self) -> Result { // Commit all writer transactions to the database. self.writer_pool .finish() @@ -549,12 +551,13 @@ impl<'conn> MasWriter<'conn> { .await .into_database("ending MAS transaction")?; - self.conn + let conn = self + .conn .unlock() .await .into_database("could not unlock MAS database")?; - Ok(()) + Ok(conn) } /// Write a batch of users to the database. @@ -1022,8 +1025,8 @@ const WRITE_BUFFER_BATCH_SIZE: usize = 4096; /// A function that can accept and flush buffers from a `MasWriteBuffer`. /// Intended uses are the methods on `MasWriter` such as `write_users`. -type WriteBufferFlusher<'conn, T> = - for<'a> fn(&'a mut MasWriter<'conn>, Vec) -> BoxFuture<'a, Result<(), Error>>; +type WriteBufferFlusher = + for<'a> fn(&'a mut MasWriter, Vec) -> BoxFuture<'a, Result<(), Error>>; /// A buffer for writing rows to the MAS database. /// Generic over the type of rows. @@ -1031,14 +1034,14 @@ type WriteBufferFlusher<'conn, T> = /// # Panics /// /// Panics if dropped before `finish()` has been called. -pub struct MasWriteBuffer<'conn, T> { +pub struct MasWriteBuffer { rows: Vec, - flusher: WriteBufferFlusher<'conn, T>, + flusher: WriteBufferFlusher, finished: bool, } -impl<'conn, T> MasWriteBuffer<'conn, T> { - pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self { +impl MasWriteBuffer { + pub fn new(flusher: WriteBufferFlusher) -> Self { MasWriteBuffer { rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE), flusher, @@ -1046,13 +1049,13 @@ impl<'conn, T> MasWriteBuffer<'conn, T> { } } - pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> { + pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> { self.finished = true; self.flush(writer).await?; Ok(()) } - pub async fn flush(&mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> { + pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> { if self.rows.is_empty() { return Ok(()); } @@ -1062,7 +1065,7 @@ impl<'conn, T> MasWriteBuffer<'conn, T> { Ok(()) } - pub async fn write(&mut self, writer: &mut MasWriter<'conn>, row: T) -> Result<(), Error> { + pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> { self.rows.push(row); if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE { self.flush(writer).await?; @@ -1071,7 +1074,7 @@ impl<'conn, T> MasWriteBuffer<'conn, T> { } } -impl Drop for MasWriteBuffer<'_, T> { +impl Drop for MasWriteBuffer { fn drop(&mut self) { assert!(self.finished, "MasWriteBuffer dropped but not finished!"); } @@ -1180,10 +1183,8 @@ mod test { /// Runs some code with a `MasWriter`. /// /// The callback is responsible for `finish`ing the `MasWriter`. - async fn make_mas_writer<'conn>( - pool: &PgPool, - main_conn: &'conn mut PgConnection, - ) -> MasWriter<'conn> { + async fn make_mas_writer(pool: &PgPool) -> MasWriter { + let main_conn = pool.acquire().await.unwrap().detach(); let mut writer_conns = Vec::new(); for _ in 0..2 { writer_conns.push( @@ -1205,8 +1206,7 @@ mod test { /// Tests writing a single user, without a password. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1220,7 +1220,7 @@ mod test { .await .expect("failed to write user"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1230,8 +1230,7 @@ mod test { async fn test_write_user_with_password(pool: PgPool) { const USER_ID: Uuid = Uuid::from_u128(1u128); - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1254,7 +1253,7 @@ mod test { .await .expect("failed to write password"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1262,8 +1261,7 @@ mod test { /// Tests writing a single user, with an e-mail address associated. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_email(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1287,7 +1285,7 @@ mod test { .await .expect("failed to write e-mail"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1296,8 +1294,7 @@ mod test { /// associated. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_unsupported_threepid(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1321,7 +1318,7 @@ mod test { .await .expect("failed to write phone number (unsupported threepid)"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1331,8 +1328,7 @@ mod test { /// real migration, this is done by running a provider sync first. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))] async fn test_write_user_with_upstream_provider_link(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1357,7 +1353,7 @@ mod test { .await .expect("failed to write link"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1365,8 +1361,7 @@ mod test { /// Tests writing a single user, with a device (compat session). #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_device(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1395,7 +1390,7 @@ mod test { .await .expect("failed to write compat session"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1403,8 +1398,7 @@ mod test { /// Tests writing a single user, with a device and an access token. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_access_token(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1444,7 +1438,7 @@ mod test { .await .expect("failed to write access token"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1453,8 +1447,7 @@ mod test { /// refresh token. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_refresh_token(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1505,7 +1498,7 @@ mod test { .await .expect("failed to write refresh token"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } diff --git a/crates/syn2mas/src/migration.rs b/crates/syn2mas/src/migration.rs index f46586c03..18d6746be 100644 --- a/crates/syn2mas/src/migration.rs +++ b/crates/syn2mas/src/migration.rs @@ -92,7 +92,7 @@ struct UsersMigrated { #[allow(clippy::implicit_hasher)] pub async fn migrate( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore, @@ -179,7 +179,7 @@ pub async fn migrate( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_users( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, user_count_hint: usize, server_name: &str, rng: &mut impl RngCore, @@ -232,7 +232,7 @@ async fn migrate_users( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_threepids( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, rng: &mut impl RngCore, user_localparts_to_uuid: &HashMap, @@ -315,7 +315,7 @@ async fn migrate_threepids( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_external_ids( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, rng: &mut impl RngCore, user_localparts_to_uuid: &HashMap, @@ -391,7 +391,7 @@ async fn migrate_external_ids( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_devices( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, rng: &mut impl RngCore, user_localparts_to_uuid: &HashMap, @@ -483,7 +483,7 @@ async fn migrate_devices( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_unrefreshable_access_tokens( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore, @@ -591,7 +591,7 @@ async fn migrate_unrefreshable_access_tokens( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_refreshable_token_pairs( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore,