Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cli/src/commands/syn2mas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Options {
.await?;
}

let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(&mut mas_connection)
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection)
.await
.context("failed to issue query to lock database")?
else {
Expand Down
4 changes: 1 addition & 3 deletions crates/syn2mas/src/mas_writer/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ pub enum Error {
/// - If we can't check whether syn2mas is already in progress on this database
/// or not.
#[tracing::instrument(skip_all)]
pub async fn mas_pre_migration_checks<'a>(
mas_connection: &mut LockedMasDatabase<'a>,
) -> Result<(), Error> {
pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> {
if is_syn2mas_in_progress(mas_connection.as_mut())
.await
.map_err(Error::UnableToCheckInProgress)?
Expand Down
14 changes: 7 additions & 7 deletions crates/syn2mas/src/mas_writer/locking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ static SYN2MAS_ADVISORY_LOCK: LazyLock<PgAdvisoryLock> =

/// A wrapper around a Postgres connection which holds a session-wide advisory
/// lock preventing concurrent access by other syn2mas instances.
pub struct LockedMasDatabase<'conn> {
inner: PgAdvisoryLockGuard<'static, &'conn mut PgConnection>,
pub struct LockedMasDatabase {
inner: PgAdvisoryLockGuard<'static, PgConnection>,
}

impl<'conn> LockedMasDatabase<'conn> {
impl LockedMasDatabase {
/// Attempts to lock the MAS database against concurrent access by other
/// syn2mas instances.
///
Expand All @@ -31,8 +31,8 @@ impl<'conn> LockedMasDatabase<'conn> {
///
/// Errors are returned for underlying database errors.
pub async fn try_new(
mas_connection: &'conn mut PgConnection,
) -> Result<Either<Self, &'conn mut PgConnection>, sqlx::Error> {
mas_connection: PgConnection,
) -> Result<Either<Self, PgConnection>, sqlx::Error> {
SYN2MAS_ADVISORY_LOCK
.try_acquire(mas_connection)
.await
Expand All @@ -48,12 +48,12 @@ impl<'conn> LockedMasDatabase<'conn> {
/// # Errors
///
/// Errors are returned for underlying database errors.
pub async fn unlock(self) -> Result<&'conn mut PgConnection, sqlx::Error> {
pub async fn unlock(self) -> Result<PgConnection, sqlx::Error> {
self.inner.release_now().await
}
}

impl AsMut<PgConnection> for LockedMasDatabase<'_> {
impl AsMut<PgConnection> for LockedMasDatabase {
fn as_mut(&mut self) -> &mut PgConnection {
self.inner.as_mut()
}
Expand Down
85 changes: 39 additions & 46 deletions crates/syn2mas/src/mas_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ impl WriterConnectionPool {
}
}

pub struct MasWriter<'c> {
conn: LockedMasDatabase<'c>,
pub struct MasWriter {
conn: LockedMasDatabase,
writer_pool: WriterConnectionPool,

indices_to_restore: Vec<IndexDescription>,
Expand Down Expand Up @@ -324,7 +324,7 @@ pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Err
}
}

impl<'conn> MasWriter<'conn> {
impl MasWriter {
/// Creates a new MAS writer.
///
/// # Errors
Expand All @@ -335,7 +335,7 @@ impl<'conn> MasWriter<'conn> {
#[allow(clippy::missing_panics_doc)] // not real
#[tracing::instrument(skip_all)]
pub async fn new(
mut conn: LockedMasDatabase<'conn>,
mut conn: LockedMasDatabase,
mut writer_connections: Vec<PgConnection>,
) -> Result<Self, Error> {
// Given that we don't have any concurrent transactions here,
Expand Down Expand Up @@ -446,6 +446,7 @@ impl<'conn> MasWriter<'conn> {

Ok(Self {
conn,

writer_pool: WriterConnectionPool::new(writer_connections),
indices_to_restore,
constraints_to_restore,
Expand Down Expand Up @@ -488,7 +489,7 @@ impl<'conn> MasWriter<'conn> {
}

async fn restore_indices(
conn: &mut LockedMasDatabase<'_>,
conn: &mut LockedMasDatabase,
indices_to_restore: &[IndexDescription],
constraints_to_restore: &[ConstraintDescription],
) -> Result<(), Error> {
Expand All @@ -507,14 +508,15 @@ impl<'conn> MasWriter<'conn> {
}

/// Finish writing to the MAS database, flushing and committing all changes.
/// It returns the unlocked underlying connection.
///
/// # Errors
///
/// Errors are returned in the following conditions:
///
/// - If the database connection experiences an error.
#[tracing::instrument(skip_all)]
pub async fn finish(mut self) -> Result<(), Error> {
pub async fn finish(mut self) -> Result<PgConnection, Error> {
// Commit all writer transactions to the database.
self.writer_pool
.finish()
Expand Down Expand Up @@ -549,12 +551,13 @@ impl<'conn> MasWriter<'conn> {
.await
.into_database("ending MAS transaction")?;

self.conn
let conn = self
.conn
.unlock()
.await
.into_database("could not unlock MAS database")?;

Ok(())
Ok(conn)
}

/// Write a batch of users to the database.
Expand Down Expand Up @@ -1022,37 +1025,37 @@ const WRITE_BUFFER_BATCH_SIZE: usize = 4096;

/// A function that can accept and flush buffers from a `MasWriteBuffer`.
/// Intended uses are the methods on `MasWriter` such as `write_users`.
type WriteBufferFlusher<'conn, T> =
for<'a> fn(&'a mut MasWriter<'conn>, Vec<T>) -> BoxFuture<'a, Result<(), Error>>;
type WriteBufferFlusher<T> =
for<'a> fn(&'a mut MasWriter, Vec<T>) -> BoxFuture<'a, Result<(), Error>>;

/// A buffer for writing rows to the MAS database.
/// Generic over the type of rows.
///
/// # Panics
///
/// Panics if dropped before `finish()` has been called.
pub struct MasWriteBuffer<'conn, T> {
pub struct MasWriteBuffer<T> {
rows: Vec<T>,
flusher: WriteBufferFlusher<'conn, T>,
flusher: WriteBufferFlusher<T>,
finished: bool,
}

impl<'conn, T> MasWriteBuffer<'conn, T> {
pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self {
impl<T> MasWriteBuffer<T> {
pub fn new(flusher: WriteBufferFlusher<T>) -> Self {
MasWriteBuffer {
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
flusher,
finished: false,
}
}

pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
self.finished = true;
self.flush(writer).await?;
Ok(())
}

pub async fn flush(&mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> {
if self.rows.is_empty() {
return Ok(());
}
Expand All @@ -1062,7 +1065,7 @@ impl<'conn, T> MasWriteBuffer<'conn, T> {
Ok(())
}

pub async fn write(&mut self, writer: &mut MasWriter<'conn>, row: T) -> Result<(), Error> {
pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> {
self.rows.push(row);
if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
self.flush(writer).await?;
Expand All @@ -1071,7 +1074,7 @@ impl<'conn, T> MasWriteBuffer<'conn, T> {
}
}

impl<T> Drop for MasWriteBuffer<'_, T> {
impl<T> Drop for MasWriteBuffer<T> {
fn drop(&mut self) {
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
}
Expand Down Expand Up @@ -1180,10 +1183,8 @@ mod test {
/// Runs some code with a `MasWriter`.
///
/// The callback is responsible for `finish`ing the `MasWriter`.
async fn make_mas_writer<'conn>(
pool: &PgPool,
main_conn: &'conn mut PgConnection,
) -> MasWriter<'conn> {
async fn make_mas_writer(pool: &PgPool) -> MasWriter {
let main_conn = pool.acquire().await.unwrap().detach();
let mut writer_conns = Vec::new();
for _ in 0..2 {
writer_conns.push(
Expand All @@ -1205,8 +1206,7 @@ mod test {
/// Tests writing a single user, without a password.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand All @@ -1220,7 +1220,7 @@ mod test {
.await
.expect("failed to write user");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}
Expand All @@ -1230,8 +1230,7 @@ mod test {
async fn test_write_user_with_password(pool: PgPool) {
const USER_ID: Uuid = Uuid::from_u128(1u128);

let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand All @@ -1254,16 +1253,15 @@ mod test {
.await
.expect("failed to write password");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}

/// Tests writing a single user, with an e-mail address associated.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user_with_email(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand All @@ -1287,7 +1285,7 @@ mod test {
.await
.expect("failed to write e-mail");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}
Expand All @@ -1296,8 +1294,7 @@ mod test {
/// associated.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand All @@ -1321,7 +1318,7 @@ mod test {
.await
.expect("failed to write phone number (unsupported threepid)");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}
Expand All @@ -1331,8 +1328,7 @@ mod test {
/// real migration, this is done by running a provider sync first.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand All @@ -1357,16 +1353,15 @@ mod test {
.await
.expect("failed to write link");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}

/// Tests writing a single user, with a device (compat session).
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user_with_device(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand Down Expand Up @@ -1395,16 +1390,15 @@ mod test {
.await
.expect("failed to write compat session");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}

/// Tests writing a single user, with a device and an access token.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user_with_access_token(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand Down Expand Up @@ -1444,7 +1438,7 @@ mod test {
.await
.expect("failed to write access token");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}
Expand All @@ -1453,8 +1447,7 @@ mod test {
/// refresh token.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_write_user_with_refresh_token(pool: PgPool) {
let mut conn = pool.acquire().await.unwrap();
let mut writer = make_mas_writer(&pool, &mut conn).await;
let mut writer = make_mas_writer(&pool).await;

writer
.write_users(vec![MasNewUser {
Expand Down Expand Up @@ -1505,7 +1498,7 @@ mod test {
.await
.expect("failed to write refresh token");

writer.finish().await.expect("failed to finish MasWriter");
let mut conn = writer.finish().await.expect("failed to finish MasWriter");

assert_db_snapshot!(&mut conn);
}
Expand Down
Loading
Loading