Skip to content

Commit e46cd72

Browse files
committed
Make the MAS connection owned in the locked database struct
1 parent 2663a0f commit e46cd72

File tree

4 files changed

+39
-46
lines changed

4 files changed

+39
-46
lines changed

crates/cli/src/commands/syn2mas.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ impl Options {
142142
.await?;
143143
}
144144

145-
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(&mut mas_connection)
145+
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection)
146146
.await
147147
.context("failed to issue query to lock database")?
148148
else {

crates/syn2mas/src/mas_writer/checks.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ pub enum Error {
4747
/// - If we can't check whether syn2mas is already in progress on this database
4848
/// or not.
4949
#[tracing::instrument(skip_all)]
50-
pub async fn mas_pre_migration_checks<'a>(
51-
mas_connection: &mut LockedMasDatabase<'a>,
52-
) -> Result<(), Error> {
50+
pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> {
5351
if is_syn2mas_in_progress(mas_connection.as_mut())
5452
.await
5553
.map_err(Error::UnableToCheckInProgress)?

crates/syn2mas/src/mas_writer/locking.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ static SYN2MAS_ADVISORY_LOCK: LazyLock<PgAdvisoryLock> =
1515

1616
/// A wrapper around a Postgres connection which holds a session-wide advisory
1717
/// lock preventing concurrent access by other syn2mas instances.
18-
pub struct LockedMasDatabase<'conn> {
19-
inner: PgAdvisoryLockGuard<'static, &'conn mut PgConnection>,
18+
pub struct LockedMasDatabase {
19+
inner: PgAdvisoryLockGuard<'static, PgConnection>,
2020
}
2121

22-
impl<'conn> LockedMasDatabase<'conn> {
22+
impl LockedMasDatabase {
2323
/// Attempts to lock the MAS database against concurrent access by other
2424
/// syn2mas instances.
2525
///
@@ -31,8 +31,8 @@ impl<'conn> LockedMasDatabase<'conn> {
3131
///
3232
/// Errors are returned for underlying database errors.
3333
pub async fn try_new(
34-
mas_connection: &'conn mut PgConnection,
35-
) -> Result<Either<Self, &'conn mut PgConnection>, sqlx::Error> {
34+
mas_connection: PgConnection,
35+
) -> Result<Either<Self, PgConnection>, sqlx::Error> {
3636
SYN2MAS_ADVISORY_LOCK
3737
.try_acquire(mas_connection)
3838
.await
@@ -48,12 +48,12 @@ impl<'conn> LockedMasDatabase<'conn> {
4848
/// # Errors
4949
///
5050
/// Errors are returned for underlying database errors.
51-
pub async fn unlock(self) -> Result<&'conn mut PgConnection, sqlx::Error> {
51+
pub async fn unlock(self) -> Result<PgConnection, sqlx::Error> {
5252
self.inner.release_now().await
5353
}
5454
}
5555

56-
impl AsMut<PgConnection> for LockedMasDatabase<'_> {
56+
impl AsMut<PgConnection> for LockedMasDatabase {
5757
fn as_mut(&mut self) -> &mut PgConnection {
5858
self.inner.as_mut()
5959
}

crates/syn2mas/src/mas_writer/mod.rs

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ impl WriterConnectionPool {
186186
}
187187

188188
pub struct MasWriter<'c> {
189-
conn: LockedMasDatabase<'c>,
189+
conn: LockedMasDatabase,
190+
// Temporary phantom data, so that we don't remove the lifetime parameter yet
191+
phantom: std::marker::PhantomData<&'c ()>,
190192
writer_pool: WriterConnectionPool,
191193

192194
indices_to_restore: Vec<IndexDescription>,
@@ -324,7 +326,7 @@ pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Err
324326
}
325327
}
326328

327-
impl<'conn> MasWriter<'conn> {
329+
impl MasWriter<'_> {
328330
/// Creates a new MAS writer.
329331
///
330332
/// # Errors
@@ -335,7 +337,7 @@ impl<'conn> MasWriter<'conn> {
335337
#[allow(clippy::missing_panics_doc)] // not real
336338
#[tracing::instrument(skip_all)]
337339
pub async fn new(
338-
mut conn: LockedMasDatabase<'conn>,
340+
mut conn: LockedMasDatabase,
339341
mut writer_connections: Vec<PgConnection>,
340342
) -> Result<Self, Error> {
341343
// Given that we don't have any concurrent transactions here,
@@ -446,6 +448,7 @@ impl<'conn> MasWriter<'conn> {
446448

447449
Ok(Self {
448450
conn,
451+
phantom: std::marker::PhantomData,
449452
writer_pool: WriterConnectionPool::new(writer_connections),
450453
indices_to_restore,
451454
constraints_to_restore,
@@ -488,7 +491,7 @@ impl<'conn> MasWriter<'conn> {
488491
}
489492

490493
async fn restore_indices(
491-
conn: &mut LockedMasDatabase<'_>,
494+
conn: &mut LockedMasDatabase,
492495
indices_to_restore: &[IndexDescription],
493496
constraints_to_restore: &[ConstraintDescription],
494497
) -> Result<(), Error> {
@@ -507,14 +510,15 @@ impl<'conn> MasWriter<'conn> {
507510
}
508511

509512
/// Finish writing to the MAS database, flushing and committing all changes.
513+
/// It returns the unlocked underlying connection.
510514
///
511515
/// # Errors
512516
///
513517
/// Errors are returned in the following conditions:
514518
///
515519
/// - If the database connection experiences an error.
516520
#[tracing::instrument(skip_all)]
517-
pub async fn finish(mut self) -> Result<(), Error> {
521+
pub async fn finish(mut self) -> Result<PgConnection, Error> {
518522
// Commit all writer transactions to the database.
519523
self.writer_pool
520524
.finish()
@@ -549,12 +553,13 @@ impl<'conn> MasWriter<'conn> {
549553
.await
550554
.into_database("ending MAS transaction")?;
551555

552-
self.conn
556+
let conn = self
557+
.conn
553558
.unlock()
554559
.await
555560
.into_database("could not unlock MAS database")?;
556561

557-
Ok(())
562+
Ok(conn)
558563
}
559564

560565
/// Write a batch of users to the database.
@@ -1180,10 +1185,8 @@ mod test {
11801185
/// Runs some code with a `MasWriter`.
11811186
///
11821187
/// The callback is responsible for `finish`ing the `MasWriter`.
1183-
async fn make_mas_writer<'conn>(
1184-
pool: &PgPool,
1185-
main_conn: &'conn mut PgConnection,
1186-
) -> MasWriter<'conn> {
1188+
async fn make_mas_writer(pool: &PgPool) -> MasWriter<'static> {
1189+
let main_conn = pool.acquire().await.unwrap().detach();
11871190
let mut writer_conns = Vec::new();
11881191
for _ in 0..2 {
11891192
writer_conns.push(
@@ -1205,8 +1208,7 @@ mod test {
12051208
/// Tests writing a single user, without a password.
12061209
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
12071210
async fn test_write_user(pool: PgPool) {
1208-
let mut conn = pool.acquire().await.unwrap();
1209-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1211+
let mut writer = make_mas_writer(&pool).await;
12101212

12111213
writer
12121214
.write_users(vec![MasNewUser {
@@ -1220,7 +1222,7 @@ mod test {
12201222
.await
12211223
.expect("failed to write user");
12221224

1223-
writer.finish().await.expect("failed to finish MasWriter");
1225+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
12241226

12251227
assert_db_snapshot!(&mut conn);
12261228
}
@@ -1230,8 +1232,7 @@ mod test {
12301232
async fn test_write_user_with_password(pool: PgPool) {
12311233
const USER_ID: Uuid = Uuid::from_u128(1u128);
12321234

1233-
let mut conn = pool.acquire().await.unwrap();
1234-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1235+
let mut writer = make_mas_writer(&pool).await;
12351236

12361237
writer
12371238
.write_users(vec![MasNewUser {
@@ -1254,16 +1255,15 @@ mod test {
12541255
.await
12551256
.expect("failed to write password");
12561257

1257-
writer.finish().await.expect("failed to finish MasWriter");
1258+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
12581259

12591260
assert_db_snapshot!(&mut conn);
12601261
}
12611262

12621263
/// Tests writing a single user, with an e-mail address associated.
12631264
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
12641265
async fn test_write_user_with_email(pool: PgPool) {
1265-
let mut conn = pool.acquire().await.unwrap();
1266-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1266+
let mut writer = make_mas_writer(&pool).await;
12671267

12681268
writer
12691269
.write_users(vec![MasNewUser {
@@ -1287,7 +1287,7 @@ mod test {
12871287
.await
12881288
.expect("failed to write e-mail");
12891289

1290-
writer.finish().await.expect("failed to finish MasWriter");
1290+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
12911291

12921292
assert_db_snapshot!(&mut conn);
12931293
}
@@ -1296,8 +1296,7 @@ mod test {
12961296
/// associated.
12971297
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
12981298
async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
1299-
let mut conn = pool.acquire().await.unwrap();
1300-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1299+
let mut writer = make_mas_writer(&pool).await;
13011300

13021301
writer
13031302
.write_users(vec![MasNewUser {
@@ -1321,7 +1320,7 @@ mod test {
13211320
.await
13221321
.expect("failed to write phone number (unsupported threepid)");
13231322

1324-
writer.finish().await.expect("failed to finish MasWriter");
1323+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
13251324

13261325
assert_db_snapshot!(&mut conn);
13271326
}
@@ -1331,8 +1330,7 @@ mod test {
13311330
/// real migration, this is done by running a provider sync first.
13321331
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
13331332
async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
1334-
let mut conn = pool.acquire().await.unwrap();
1335-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1333+
let mut writer = make_mas_writer(&pool).await;
13361334

13371335
writer
13381336
.write_users(vec![MasNewUser {
@@ -1357,16 +1355,15 @@ mod test {
13571355
.await
13581356
.expect("failed to write link");
13591357

1360-
writer.finish().await.expect("failed to finish MasWriter");
1358+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
13611359

13621360
assert_db_snapshot!(&mut conn);
13631361
}
13641362

13651363
/// Tests writing a single user, with a device (compat session).
13661364
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
13671365
async fn test_write_user_with_device(pool: PgPool) {
1368-
let mut conn = pool.acquire().await.unwrap();
1369-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1366+
let mut writer = make_mas_writer(&pool).await;
13701367

13711368
writer
13721369
.write_users(vec![MasNewUser {
@@ -1395,16 +1392,15 @@ mod test {
13951392
.await
13961393
.expect("failed to write compat session");
13971394

1398-
writer.finish().await.expect("failed to finish MasWriter");
1395+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
13991396

14001397
assert_db_snapshot!(&mut conn);
14011398
}
14021399

14031400
/// Tests writing a single user, with a device and an access token.
14041401
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
14051402
async fn test_write_user_with_access_token(pool: PgPool) {
1406-
let mut conn = pool.acquire().await.unwrap();
1407-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1403+
let mut writer = make_mas_writer(&pool).await;
14081404

14091405
writer
14101406
.write_users(vec![MasNewUser {
@@ -1444,7 +1440,7 @@ mod test {
14441440
.await
14451441
.expect("failed to write access token");
14461442

1447-
writer.finish().await.expect("failed to finish MasWriter");
1443+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
14481444

14491445
assert_db_snapshot!(&mut conn);
14501446
}
@@ -1453,8 +1449,7 @@ mod test {
14531449
/// refresh token.
14541450
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
14551451
async fn test_write_user_with_refresh_token(pool: PgPool) {
1456-
let mut conn = pool.acquire().await.unwrap();
1457-
let mut writer = make_mas_writer(&pool, &mut conn).await;
1452+
let mut writer = make_mas_writer(&pool).await;
14581453

14591454
writer
14601455
.write_users(vec![MasNewUser {
@@ -1505,7 +1500,7 @@ mod test {
15051500
.await
15061501
.expect("failed to write refresh token");
15071502

1508-
writer.finish().await.expect("failed to finish MasWriter");
1503+
let mut conn = writer.finish().await.expect("failed to finish MasWriter");
15091504

15101505
assert_db_snapshot!(&mut conn);
15111506
}

0 commit comments

Comments
 (0)