Skip to content

Commit 487c53c

Browse files
committed
Replace panic on unfinished buffers with an error
This doesn't get clobbered by the progress bar
1 parent 669ec70 commit 487c53c

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

crates/syn2mas/src/mas_writer/mod.rs

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
//!
88
//! This module is responsible for writing new records to MAS' database.
99
10-
use std::{fmt::Display, net::IpAddr};
10+
use std::{
11+
fmt::Display,
12+
net::IpAddr,
13+
sync::{
14+
Arc,
15+
atomic::{AtomicU32, Ordering},
16+
},
17+
};
1118

1219
use chrono::{DateTime, Utc};
1320
use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
@@ -44,6 +51,9 @@ pub enum Error {
4451
#[error("inconsistent database: {0}")]
4552
Inconsistent(String),
4653

54+
#[error("bug in syn2mas: write buffers not finished")]
55+
WriteBuffersNotFinished,
56+
4757
#[error("{0}")]
4858
Multiple(MultipleErrors),
4959
}
@@ -188,12 +198,52 @@ impl WriterConnectionPool {
188198
}
189199
}
190200

201+
/// Small utility to make sure `finish()` is called on all write buffers
202+
/// before committing to the database.
203+
#[derive(Default)]
204+
struct FinishChecker {
205+
counter: Arc<AtomicU32>,
206+
}
207+
208+
struct FinishCheckerHandle {
209+
counter: Arc<AtomicU32>,
210+
}
211+
212+
impl FinishChecker {
213+
/// Acquire a new handle, for a task that should declare when it has
214+
/// finished.
215+
pub fn handle(&self) -> FinishCheckerHandle {
216+
self.counter.fetch_add(1, Ordering::SeqCst);
217+
FinishCheckerHandle {
218+
counter: Arc::clone(&self.counter),
219+
}
220+
}
221+
222+
/// Check that all handles have been declared as finished.
223+
pub fn check_all_finished(self) -> Result<(), Error> {
224+
if self.counter.load(Ordering::SeqCst) == 0 {
225+
Ok(())
226+
} else {
227+
Err(Error::WriteBuffersNotFinished)
228+
}
229+
}
230+
}
231+
232+
impl FinishCheckerHandle {
233+
/// Declare that the task this handle represents has been finished.
234+
pub fn declare_finished(self) {
235+
self.counter.fetch_sub(1, Ordering::SeqCst);
236+
}
237+
}
238+
191239
pub struct MasWriter {
192240
conn: LockedMasDatabase,
193241
writer_pool: WriterConnectionPool,
194242

195243
indices_to_restore: Vec<IndexDescription>,
196244
constraints_to_restore: Vec<ConstraintDescription>,
245+
246+
write_buffer_finish_checker: FinishChecker,
197247
}
198248

199249
pub struct MasNewUser {
@@ -453,6 +503,7 @@ impl MasWriter {
453503
writer_pool: WriterConnectionPool::new(writer_connections),
454504
indices_to_restore,
455505
constraints_to_restore,
506+
write_buffer_finish_checker: FinishChecker::default(),
456507
})
457508
}
458509

@@ -520,6 +571,8 @@ impl MasWriter {
520571
/// - If the database connection experiences an error.
521572
#[tracing::instrument(skip_all)]
522573
pub async fn finish(mut self) -> Result<PgConnection, Error> {
574+
self.write_buffer_finish_checker.check_all_finished()?;
575+
523576
// Commit all writer transactions to the database.
524577
self.writer_pool
525578
.finish()
@@ -1033,28 +1086,24 @@ type WriteBufferFlusher<T> =
10331086

10341087
/// A buffer for writing rows to the MAS database.
10351088
/// Generic over the type of rows.
1036-
///
1037-
/// # Panics
1038-
///
1039-
/// Panics if dropped before `finish()` has been called.
10401089
pub struct MasWriteBuffer<T> {
10411090
rows: Vec<T>,
10421091
flusher: WriteBufferFlusher<T>,
1043-
finished: bool,
1092+
finish_checker_handle: FinishCheckerHandle,
10441093
}
10451094

10461095
impl<T> MasWriteBuffer<T> {
1047-
pub fn new(flusher: WriteBufferFlusher<T>) -> Self {
1096+
pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<T>) -> Self {
10481097
MasWriteBuffer {
10491098
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
10501099
flusher,
1051-
finished: false,
1100+
finish_checker_handle: writer.write_buffer_finish_checker.handle(),
10521101
}
10531102
}
10541103

10551104
pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1056-
self.finished = true;
10571105
self.flush(writer).await?;
1106+
self.finish_checker_handle.declare_finished();
10581107
Ok(())
10591108
}
10601109

@@ -1077,12 +1126,6 @@ impl<T> MasWriteBuffer<T> {
10771126
}
10781127
}
10791128

1080-
impl<T> Drop for MasWriteBuffer<T> {
1081-
fn drop(&mut self) {
1082-
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
1083-
}
1084-
}
1085-
10861129
#[cfg(test)]
10871130
mod test {
10881131
use std::collections::{BTreeMap, BTreeSet};

crates/syn2mas/src/migration.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ async fn migrate_users(
175175
mut state: MigrationState,
176176
rng: &mut impl RngCore,
177177
) -> Result<(MasWriter, MigrationState), Error> {
178-
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
179-
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
178+
let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users);
179+
let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords);
180180
let mut users_stream = pin!(synapse.read_users());
181181

182182
while let Some(user_res) = users_stream.next().await {
@@ -262,8 +262,8 @@ async fn migrate_threepids(
262262
rng: &mut impl RngCore,
263263
state: MigrationState,
264264
) -> Result<(MasWriter, MigrationState), Error> {
265-
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
266-
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
265+
let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids);
266+
let mut unsupported_buffer = MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids);
267267
let mut users_stream = pin!(synapse.read_threepids());
268268

269269
while let Some(threepid_res) = users_stream.next().await {
@@ -345,7 +345,7 @@ async fn migrate_external_ids(
345345
rng: &mut impl RngCore,
346346
state: MigrationState,
347347
) -> Result<(MasWriter, MigrationState), Error> {
348-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
348+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links);
349349
let mut extids_stream = pin!(synapse.read_user_external_ids());
350350

351351
while let Some(extid_res) = extids_stream.next().await {
@@ -421,7 +421,7 @@ async fn migrate_devices(
421421
mut state: MigrationState,
422422
) -> Result<(MasWriter, MigrationState), Error> {
423423
let mut devices_stream = pin!(synapse.read_devices());
424-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
424+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
425425

426426
while let Some(device_res) = devices_stream.next().await {
427427
let SynapseDevice {
@@ -521,8 +521,9 @@ async fn migrate_unrefreshable_access_tokens(
521521
mut state: MigrationState,
522522
) -> Result<(MasWriter, MigrationState), Error> {
523523
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
524-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
525-
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
524+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
525+
let mut deviceless_session_write_buffer =
526+
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
526527

527528
while let Some(token_res) = token_stream.next().await {
528529
let SynapseAccessToken {
@@ -635,9 +636,10 @@ async fn migrate_refreshable_token_pairs(
635636
mut state: MigrationState,
636637
) -> Result<(MasWriter, MigrationState), Error> {
637638
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
638-
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
639+
let mut access_token_write_buffer =
640+
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
639641
let mut refresh_token_write_buffer =
640-
MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
642+
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
641643

642644
while let Some(token_res) = token_stream.next().await {
643645
let SynapseRefreshableTokenPair {

0 commit comments

Comments
 (0)