From e46cd724f4c9f08e56d6966450152019828df38b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 7 Feb 2025 12:53:51 +0100 Subject: [PATCH 1/2] Make the MAS connection owned in the locked database struct --- crates/cli/src/commands/syn2mas.rs | 2 +- crates/syn2mas/src/mas_writer/checks.rs | 4 +- crates/syn2mas/src/mas_writer/locking.rs | 14 ++--- crates/syn2mas/src/mas_writer/mod.rs | 65 +++++++++++------------- 4 files changed, 39 insertions(+), 46 deletions(-) diff --git a/crates/cli/src/commands/syn2mas.rs b/crates/cli/src/commands/syn2mas.rs index 63fe440eb..e6ab68759 100644 --- a/crates/cli/src/commands/syn2mas.rs +++ b/crates/cli/src/commands/syn2mas.rs @@ -142,7 +142,7 @@ impl Options { .await?; } - let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(&mut mas_connection) + let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection) .await .context("failed to issue query to lock database")? else { diff --git a/crates/syn2mas/src/mas_writer/checks.rs b/crates/syn2mas/src/mas_writer/checks.rs index a8ea1a18a..f3e98e57c 100644 --- a/crates/syn2mas/src/mas_writer/checks.rs +++ b/crates/syn2mas/src/mas_writer/checks.rs @@ -47,9 +47,7 @@ pub enum Error { /// - If we can't check whether syn2mas is already in progress on this database /// or not. #[tracing::instrument(skip_all)] -pub async fn mas_pre_migration_checks<'a>( - mas_connection: &mut LockedMasDatabase<'a>, -) -> Result<(), Error> { +pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> { if is_syn2mas_in_progress(mas_connection.as_mut()) .await .map_err(Error::UnableToCheckInProgress)? diff --git a/crates/syn2mas/src/mas_writer/locking.rs b/crates/syn2mas/src/mas_writer/locking.rs index f034025bf..2425ca269 100644 --- a/crates/syn2mas/src/mas_writer/locking.rs +++ b/crates/syn2mas/src/mas_writer/locking.rs @@ -15,11 +15,11 @@ static SYN2MAS_ADVISORY_LOCK: LazyLock = /// A wrapper around a Postgres connection which holds a session-wide advisory /// lock preventing concurrent access by other syn2mas instances. -pub struct LockedMasDatabase<'conn> { - inner: PgAdvisoryLockGuard<'static, &'conn mut PgConnection>, +pub struct LockedMasDatabase { + inner: PgAdvisoryLockGuard<'static, PgConnection>, } -impl<'conn> LockedMasDatabase<'conn> { +impl LockedMasDatabase { /// Attempts to lock the MAS database against concurrent access by other /// syn2mas instances. /// @@ -31,8 +31,8 @@ impl<'conn> LockedMasDatabase<'conn> { /// /// Errors are returned for underlying database errors. pub async fn try_new( - mas_connection: &'conn mut PgConnection, - ) -> Result, sqlx::Error> { + mas_connection: PgConnection, + ) -> Result, sqlx::Error> { SYN2MAS_ADVISORY_LOCK .try_acquire(mas_connection) .await @@ -48,12 +48,12 @@ impl<'conn> LockedMasDatabase<'conn> { /// # Errors /// /// Errors are returned for underlying database errors. - pub async fn unlock(self) -> Result<&'conn mut PgConnection, sqlx::Error> { + pub async fn unlock(self) -> Result { self.inner.release_now().await } } -impl AsMut for LockedMasDatabase<'_> { +impl AsMut for LockedMasDatabase { fn as_mut(&mut self) -> &mut PgConnection { self.inner.as_mut() } diff --git a/crates/syn2mas/src/mas_writer/mod.rs b/crates/syn2mas/src/mas_writer/mod.rs index f1362702c..a28821e31 100644 --- a/crates/syn2mas/src/mas_writer/mod.rs +++ b/crates/syn2mas/src/mas_writer/mod.rs @@ -186,7 +186,9 @@ impl WriterConnectionPool { } pub struct MasWriter<'c> { - conn: LockedMasDatabase<'c>, + conn: LockedMasDatabase, + // Temporary phantom data, so that we don't remove the lifetime parameter yet + phantom: std::marker::PhantomData<&'c ()>, writer_pool: WriterConnectionPool, indices_to_restore: Vec, @@ -324,7 +326,7 @@ pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result MasWriter<'conn> { +impl MasWriter<'_> { /// Creates a new MAS writer. /// /// # Errors @@ -335,7 +337,7 @@ impl<'conn> MasWriter<'conn> { #[allow(clippy::missing_panics_doc)] // not real #[tracing::instrument(skip_all)] pub async fn new( - mut conn: LockedMasDatabase<'conn>, + mut conn: LockedMasDatabase, mut writer_connections: Vec, ) -> Result { // Given that we don't have any concurrent transactions here, @@ -446,6 +448,7 @@ impl<'conn> MasWriter<'conn> { Ok(Self { conn, + phantom: std::marker::PhantomData, writer_pool: WriterConnectionPool::new(writer_connections), indices_to_restore, constraints_to_restore, @@ -488,7 +491,7 @@ impl<'conn> MasWriter<'conn> { } async fn restore_indices( - conn: &mut LockedMasDatabase<'_>, + conn: &mut LockedMasDatabase, indices_to_restore: &[IndexDescription], constraints_to_restore: &[ConstraintDescription], ) -> Result<(), Error> { @@ -507,6 +510,7 @@ impl<'conn> MasWriter<'conn> { } /// Finish writing to the MAS database, flushing and committing all changes. + /// It returns the unlocked underlying connection. /// /// # Errors /// @@ -514,7 +518,7 @@ impl<'conn> MasWriter<'conn> { /// /// - If the database connection experiences an error. #[tracing::instrument(skip_all)] - pub async fn finish(mut self) -> Result<(), Error> { + pub async fn finish(mut self) -> Result { // Commit all writer transactions to the database. self.writer_pool .finish() @@ -549,12 +553,13 @@ impl<'conn> MasWriter<'conn> { .await .into_database("ending MAS transaction")?; - self.conn + let conn = self + .conn .unlock() .await .into_database("could not unlock MAS database")?; - Ok(()) + Ok(conn) } /// Write a batch of users to the database. @@ -1180,10 +1185,8 @@ mod test { /// Runs some code with a `MasWriter`. /// /// The callback is responsible for `finish`ing the `MasWriter`. - async fn make_mas_writer<'conn>( - pool: &PgPool, - main_conn: &'conn mut PgConnection, - ) -> MasWriter<'conn> { + async fn make_mas_writer(pool: &PgPool) -> MasWriter<'static> { + let main_conn = pool.acquire().await.unwrap().detach(); let mut writer_conns = Vec::new(); for _ in 0..2 { writer_conns.push( @@ -1205,8 +1208,7 @@ mod test { /// Tests writing a single user, without a password. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1220,7 +1222,7 @@ mod test { .await .expect("failed to write user"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1230,8 +1232,7 @@ mod test { async fn test_write_user_with_password(pool: PgPool) { const USER_ID: Uuid = Uuid::from_u128(1u128); - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1254,7 +1255,7 @@ mod test { .await .expect("failed to write password"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1262,8 +1263,7 @@ mod test { /// Tests writing a single user, with an e-mail address associated. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_email(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1287,7 +1287,7 @@ mod test { .await .expect("failed to write e-mail"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1296,8 +1296,7 @@ mod test { /// associated. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_unsupported_threepid(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1321,7 +1320,7 @@ mod test { .await .expect("failed to write phone number (unsupported threepid)"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1331,8 +1330,7 @@ mod test { /// real migration, this is done by running a provider sync first. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))] async fn test_write_user_with_upstream_provider_link(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1357,7 +1355,7 @@ mod test { .await .expect("failed to write link"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1365,8 +1363,7 @@ mod test { /// Tests writing a single user, with a device (compat session). #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_device(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1395,7 +1392,7 @@ mod test { .await .expect("failed to write compat session"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1403,8 +1400,7 @@ mod test { /// Tests writing a single user, with a device and an access token. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_access_token(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1444,7 +1440,7 @@ mod test { .await .expect("failed to write access token"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } @@ -1453,8 +1449,7 @@ mod test { /// refresh token. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_write_user_with_refresh_token(pool: PgPool) { - let mut conn = pool.acquire().await.unwrap(); - let mut writer = make_mas_writer(&pool, &mut conn).await; + let mut writer = make_mas_writer(&pool).await; writer .write_users(vec![MasNewUser { @@ -1505,7 +1500,7 @@ mod test { .await .expect("failed to write refresh token"); - writer.finish().await.expect("failed to finish MasWriter"); + let mut conn = writer.finish().await.expect("failed to finish MasWriter"); assert_db_snapshot!(&mut conn); } From 919c86c86e49cdd1df4b9b15f47cbd6c10b6e71a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 7 Feb 2025 12:57:21 +0100 Subject: [PATCH 2/2] Remove the lifetime parameter from MasWriter --- crates/syn2mas/src/mas_writer/mod.rs | 30 +++++++++++++--------------- crates/syn2mas/src/migration.rs | 14 ++++++------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/crates/syn2mas/src/mas_writer/mod.rs b/crates/syn2mas/src/mas_writer/mod.rs index a28821e31..e5260ad4f 100644 --- a/crates/syn2mas/src/mas_writer/mod.rs +++ b/crates/syn2mas/src/mas_writer/mod.rs @@ -185,10 +185,8 @@ impl WriterConnectionPool { } } -pub struct MasWriter<'c> { +pub struct MasWriter { conn: LockedMasDatabase, - // Temporary phantom data, so that we don't remove the lifetime parameter yet - phantom: std::marker::PhantomData<&'c ()>, writer_pool: WriterConnectionPool, indices_to_restore: Vec, @@ -326,7 +324,7 @@ pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result { +impl MasWriter { /// Creates a new MAS writer. /// /// # Errors @@ -448,7 +446,7 @@ impl MasWriter<'_> { Ok(Self { conn, - phantom: std::marker::PhantomData, + writer_pool: WriterConnectionPool::new(writer_connections), indices_to_restore, constraints_to_restore, @@ -1027,8 +1025,8 @@ const WRITE_BUFFER_BATCH_SIZE: usize = 4096; /// A function that can accept and flush buffers from a `MasWriteBuffer`. /// Intended uses are the methods on `MasWriter` such as `write_users`. -type WriteBufferFlusher<'conn, T> = - for<'a> fn(&'a mut MasWriter<'conn>, Vec) -> BoxFuture<'a, Result<(), Error>>; +type WriteBufferFlusher = + for<'a> fn(&'a mut MasWriter, Vec) -> BoxFuture<'a, Result<(), Error>>; /// A buffer for writing rows to the MAS database. /// Generic over the type of rows. @@ -1036,14 +1034,14 @@ type WriteBufferFlusher<'conn, T> = /// # Panics /// /// Panics if dropped before `finish()` has been called. -pub struct MasWriteBuffer<'conn, T> { +pub struct MasWriteBuffer { rows: Vec, - flusher: WriteBufferFlusher<'conn, T>, + flusher: WriteBufferFlusher, finished: bool, } -impl<'conn, T> MasWriteBuffer<'conn, T> { - pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self { +impl MasWriteBuffer { + pub fn new(flusher: WriteBufferFlusher) -> Self { MasWriteBuffer { rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE), flusher, @@ -1051,13 +1049,13 @@ impl<'conn, T> MasWriteBuffer<'conn, T> { } } - pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> { + pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> { self.finished = true; self.flush(writer).await?; Ok(()) } - pub async fn flush(&mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> { + pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> { if self.rows.is_empty() { return Ok(()); } @@ -1067,7 +1065,7 @@ impl<'conn, T> MasWriteBuffer<'conn, T> { Ok(()) } - pub async fn write(&mut self, writer: &mut MasWriter<'conn>, row: T) -> Result<(), Error> { + pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> { self.rows.push(row); if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE { self.flush(writer).await?; @@ -1076,7 +1074,7 @@ impl<'conn, T> MasWriteBuffer<'conn, T> { } } -impl Drop for MasWriteBuffer<'_, T> { +impl Drop for MasWriteBuffer { fn drop(&mut self) { assert!(self.finished, "MasWriteBuffer dropped but not finished!"); } @@ -1185,7 +1183,7 @@ mod test { /// Runs some code with a `MasWriter`. /// /// The callback is responsible for `finish`ing the `MasWriter`. - async fn make_mas_writer(pool: &PgPool) -> MasWriter<'static> { + async fn make_mas_writer(pool: &PgPool) -> MasWriter { let main_conn = pool.acquire().await.unwrap().detach(); let mut writer_conns = Vec::new(); for _ in 0..2 { diff --git a/crates/syn2mas/src/migration.rs b/crates/syn2mas/src/migration.rs index f46586c03..18d6746be 100644 --- a/crates/syn2mas/src/migration.rs +++ b/crates/syn2mas/src/migration.rs @@ -92,7 +92,7 @@ struct UsersMigrated { #[allow(clippy::implicit_hasher)] pub async fn migrate( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore, @@ -179,7 +179,7 @@ pub async fn migrate( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_users( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, user_count_hint: usize, server_name: &str, rng: &mut impl RngCore, @@ -232,7 +232,7 @@ async fn migrate_users( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_threepids( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, rng: &mut impl RngCore, user_localparts_to_uuid: &HashMap, @@ -315,7 +315,7 @@ async fn migrate_threepids( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_external_ids( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, rng: &mut impl RngCore, user_localparts_to_uuid: &HashMap, @@ -391,7 +391,7 @@ async fn migrate_external_ids( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_devices( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, rng: &mut impl RngCore, user_localparts_to_uuid: &HashMap, @@ -483,7 +483,7 @@ async fn migrate_devices( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_unrefreshable_access_tokens( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore, @@ -591,7 +591,7 @@ async fn migrate_unrefreshable_access_tokens( #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_refreshable_token_pairs( synapse: &mut SynapseReader<'_>, - mas: &mut MasWriter<'_>, + mas: &mut MasWriter, server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore,