Skip to content

Commit 8d738ed

Browse files
reivilibresandhose
authored andcommitted
Replace panic on unfinished buffers with an error
This doesn't get clobbered by the progress bar
1 parent 31fae19 commit 8d738ed

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

crates/syn2mas/src/mas_writer/mod.rs

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
//!
88
//! This module is responsible for writing new records to MAS' database.
99
10-
use std::{fmt::Display, net::IpAddr};
10+
use std::{
11+
fmt::Display,
12+
net::IpAddr,
13+
sync::{
14+
Arc,
15+
atomic::{AtomicU32, Ordering},
16+
},
17+
};
1118

1219
use chrono::{DateTime, Utc};
1320
use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
@@ -44,6 +51,9 @@ pub enum Error {
4451
#[error("inconsistent database: {0}")]
4552
Inconsistent(String),
4653

54+
#[error("bug in syn2mas: write buffers not finished")]
55+
WriteBuffersNotFinished,
56+
4757
#[error("{0}")]
4858
Multiple(MultipleErrors),
4959
}
@@ -188,12 +198,52 @@ impl WriterConnectionPool {
188198
}
189199
}
190200

201+
/// Small utility to make sure `finish()` is called on all write buffers
202+
/// before committing to the database.
203+
#[derive(Default)]
204+
struct FinishChecker {
205+
counter: Arc<AtomicU32>,
206+
}
207+
208+
struct FinishCheckerHandle {
209+
counter: Arc<AtomicU32>,
210+
}
211+
212+
impl FinishChecker {
213+
/// Acquire a new handle, for a task that should declare when it has
214+
/// finished.
215+
pub fn handle(&self) -> FinishCheckerHandle {
216+
self.counter.fetch_add(1, Ordering::SeqCst);
217+
FinishCheckerHandle {
218+
counter: Arc::clone(&self.counter),
219+
}
220+
}
221+
222+
/// Check that all handles have been declared as finished.
223+
pub fn check_all_finished(self) -> Result<(), Error> {
224+
if self.counter.load(Ordering::SeqCst) == 0 {
225+
Ok(())
226+
} else {
227+
Err(Error::WriteBuffersNotFinished)
228+
}
229+
}
230+
}
231+
232+
impl FinishCheckerHandle {
233+
/// Declare that the task this handle represents has been finished.
234+
pub fn declare_finished(self) {
235+
self.counter.fetch_sub(1, Ordering::SeqCst);
236+
}
237+
}
238+
191239
pub struct MasWriter {
192240
conn: LockedMasDatabase,
193241
writer_pool: WriterConnectionPool,
194242

195243
indices_to_restore: Vec<IndexDescription>,
196244
constraints_to_restore: Vec<ConstraintDescription>,
245+
246+
write_buffer_finish_checker: FinishChecker,
197247
}
198248

199249
pub struct MasNewUser {
@@ -453,6 +503,7 @@ impl MasWriter {
453503
writer_pool: WriterConnectionPool::new(writer_connections),
454504
indices_to_restore,
455505
constraints_to_restore,
506+
write_buffer_finish_checker: FinishChecker::default(),
456507
})
457508
}
458509

@@ -520,6 +571,8 @@ impl MasWriter {
520571
/// - If the database connection experiences an error.
521572
#[tracing::instrument(skip_all)]
522573
pub async fn finish(mut self) -> Result<PgConnection, Error> {
574+
self.write_buffer_finish_checker.check_all_finished()?;
575+
523576
// Commit all writer transactions to the database.
524577
self.writer_pool
525578
.finish()
@@ -1033,28 +1086,24 @@ type WriteBufferFlusher<T> =
10331086

10341087
/// A buffer for writing rows to the MAS database.
10351088
/// Generic over the type of rows.
1036-
///
1037-
/// # Panics
1038-
///
1039-
/// Panics if dropped before `finish()` has been called.
10401089
pub struct MasWriteBuffer<T> {
10411090
rows: Vec<T>,
10421091
flusher: WriteBufferFlusher<T>,
1043-
finished: bool,
1092+
finish_checker_handle: FinishCheckerHandle,
10441093
}
10451094

10461095
impl<T> MasWriteBuffer<T> {
1047-
pub fn new(flusher: WriteBufferFlusher<T>) -> Self {
1096+
pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<T>) -> Self {
10481097
MasWriteBuffer {
10491098
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
10501099
flusher,
1051-
finished: false,
1100+
finish_checker_handle: writer.write_buffer_finish_checker.handle(),
10521101
}
10531102
}
10541103

10551104
pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1056-
self.finished = true;
10571105
self.flush(writer).await?;
1106+
self.finish_checker_handle.declare_finished();
10581107
Ok(())
10591108
}
10601109

@@ -1077,12 +1126,6 @@ impl<T> MasWriteBuffer<T> {
10771126
}
10781127
}
10791128

1080-
impl<T> Drop for MasWriteBuffer<T> {
1081-
fn drop(&mut self) {
1082-
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
1083-
}
1084-
}
1085-
10861129
#[cfg(test)]
10871130
mod test {
10881131
use std::collections::{BTreeMap, BTreeSet};

crates/syn2mas/src/migration.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ async fn migrate_users(
232232
span.pb_set_style(&ProgressStyle::default_bar());
233233
span.pb_set_length(count_hint as u64);
234234

235-
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
236-
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
235+
let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users);
236+
let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords);
237237
let mut users_stream = pin!(synapse.read_users());
238238

239239
while let Some(user_res) = users_stream.next().await {
@@ -312,8 +312,8 @@ async fn migrate_threepids(
312312
span.pb_set_style(&ProgressStyle::default_bar());
313313
span.pb_set_length(count_hint as u64);
314314

315-
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
316-
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
315+
let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids);
316+
let mut unsupported_buffer = MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids);
317317
let mut users_stream = pin!(synapse.read_threepids());
318318

319319
while let Some(threepid_res) = users_stream.next().await {
@@ -402,7 +402,7 @@ async fn migrate_external_ids(
402402
span.pb_set_style(&ProgressStyle::default_bar());
403403
span.pb_set_length(count_hint as u64);
404404

405-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
405+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links);
406406
let mut extids_stream = pin!(synapse.read_user_external_ids());
407407

408408
while let Some(extid_res) = extids_stream.next().await {
@@ -486,7 +486,7 @@ async fn migrate_devices(
486486
span.pb_set_length(count_hint as u64);
487487

488488
let mut devices_stream = pin!(synapse.read_devices());
489-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
489+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
490490

491491
while let Some(device_res) = devices_stream.next().await {
492492
span.pb_inc(1);
@@ -590,8 +590,9 @@ async fn migrate_unrefreshable_access_tokens(
590590
span.pb_set_length(count_hint as u64);
591591

592592
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
593-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
594-
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
593+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
594+
let mut deviceless_session_write_buffer =
595+
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
595596

596597
while let Some(token_res) = token_stream.next().await {
597598
span.pb_inc(1);
@@ -708,9 +709,10 @@ async fn migrate_refreshable_token_pairs(
708709
span.pb_set_length(count_hint as u64);
709710

710711
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
711-
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
712+
let mut access_token_write_buffer =
713+
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
712714
let mut refresh_token_write_buffer =
713-
MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
715+
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
714716

715717
while let Some(token_res) = token_stream.next().await {
716718
span.pb_inc(1);

0 commit comments

Comments
 (0)