Skip to content

Commit 4cc2d34

Browse files
reivilibresandhose
authored andcommitted
Replace panic on unfinished buffers with an error
This doesn't get clobbered by the progress bar
1 parent 937fd05 commit 4cc2d34

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

crates/syn2mas/src/mas_writer/mod.rs

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
//!
88
//! This module is responsible for writing new records to MAS' database.
99
10-
use std::{fmt::Display, net::IpAddr};
10+
use std::{
11+
fmt::Display,
12+
net::IpAddr,
13+
sync::{
14+
atomic::{AtomicU32, Ordering},
15+
Arc,
16+
},
17+
};
1118

1219
use chrono::{DateTime, Utc};
1320
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
@@ -44,6 +51,9 @@ pub enum Error {
4451
#[error("inconsistent database: {0}")]
4552
Inconsistent(String),
4653

54+
#[error("bug in syn2mas: write buffers not finished")]
55+
WriteBuffersNotFinished,
56+
4757
#[error("{0}")]
4858
Multiple(MultipleErrors),
4959
}
@@ -185,12 +195,52 @@ impl WriterConnectionPool {
185195
}
186196
}
187197

198+
/// Small utility to make sure `finish()` is called on all write buffers
199+
/// before committing to the database.
200+
#[derive(Default)]
201+
struct FinishChecker {
202+
counter: Arc<AtomicU32>,
203+
}
204+
205+
struct FinishCheckerHandle {
206+
counter: Arc<AtomicU32>,
207+
}
208+
209+
impl FinishChecker {
210+
/// Acquire a new handle, for a task that should declare when it has
211+
/// finished.
212+
pub fn handle(&self) -> FinishCheckerHandle {
213+
self.counter.fetch_add(1, Ordering::SeqCst);
214+
FinishCheckerHandle {
215+
counter: Arc::clone(&self.counter),
216+
}
217+
}
218+
219+
/// Check that all handles have been declared as finished.
220+
pub fn check_all_finished(self) -> Result<(), Error> {
221+
if self.counter.load(Ordering::SeqCst) == 0 {
222+
Ok(())
223+
} else {
224+
Err(Error::WriteBuffersNotFinished)
225+
}
226+
}
227+
}
228+
229+
impl FinishCheckerHandle {
230+
/// Declare that the task this handle represents has been finished.
231+
pub fn declare_finished(self) {
232+
self.counter.fetch_sub(1, Ordering::SeqCst);
233+
}
234+
}
235+
188236
pub struct MasWriter {
189237
conn: LockedMasDatabase,
190238
writer_pool: WriterConnectionPool,
191239

192240
indices_to_restore: Vec<IndexDescription>,
193241
constraints_to_restore: Vec<ConstraintDescription>,
242+
243+
write_buffer_finish_checker: FinishChecker,
194244
}
195245

196246
pub struct MasNewUser {
@@ -450,6 +500,7 @@ impl MasWriter {
450500
writer_pool: WriterConnectionPool::new(writer_connections),
451501
indices_to_restore,
452502
constraints_to_restore,
503+
write_buffer_finish_checker: FinishChecker::default(),
453504
})
454505
}
455506

@@ -517,6 +568,8 @@ impl MasWriter {
517568
/// - If the database connection experiences an error.
518569
#[tracing::instrument(skip_all)]
519570
pub async fn finish(mut self) -> Result<PgConnection, Error> {
571+
self.write_buffer_finish_checker.check_all_finished()?;
572+
520573
// Commit all writer transactions to the database.
521574
self.writer_pool
522575
.finish()
@@ -1030,28 +1083,24 @@ type WriteBufferFlusher<T> =
10301083

10311084
/// A buffer for writing rows to the MAS database.
10321085
/// Generic over the type of rows.
1033-
///
1034-
/// # Panics
1035-
///
1036-
/// Panics if dropped before `finish()` has been called.
10371086
pub struct MasWriteBuffer<T> {
10381087
rows: Vec<T>,
10391088
flusher: WriteBufferFlusher<T>,
1040-
finished: bool,
1089+
finish_checker_handle: FinishCheckerHandle,
10411090
}
10421091

10431092
impl<T> MasWriteBuffer<T> {
1044-
pub fn new(flusher: WriteBufferFlusher<T>) -> Self {
1093+
pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<T>) -> Self {
10451094
MasWriteBuffer {
10461095
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
10471096
flusher,
1048-
finished: false,
1097+
finish_checker_handle: writer.write_buffer_finish_checker.handle(),
10491098
}
10501099
}
10511100

10521101
pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1053-
self.finished = true;
10541102
self.flush(writer).await?;
1103+
self.finish_checker_handle.declare_finished();
10551104
Ok(())
10561105
}
10571106

@@ -1074,12 +1123,6 @@ impl<T> MasWriteBuffer<T> {
10741123
}
10751124
}
10761125

1077-
impl<T> Drop for MasWriteBuffer<T> {
1078-
fn drop(&mut self) {
1079-
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
1080-
}
1081-
}
1082-
10831126
#[cfg(test)]
10841127
mod test {
10851128
use std::collections::{BTreeMap, BTreeSet};

crates/syn2mas/src/migration.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ async fn migrate_users(
225225
span.pb_set_style(&ProgressStyle::default_bar());
226226
span.pb_set_length(count_hint as u64);
227227

228-
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
229-
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
228+
let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users);
229+
let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords);
230230
let mut users_stream = pin!(synapse.read_users());
231231

232232
while let Some(user_res) = users_stream.next().await {
@@ -291,8 +291,8 @@ async fn migrate_threepids(
291291
span.pb_set_style(&ProgressStyle::default_bar());
292292
span.pb_set_length(count_hint as u64);
293293

294-
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
295-
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
294+
let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids);
295+
let mut unsupported_buffer = MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids);
296296
let mut users_stream = pin!(synapse.read_threepids());
297297

298298
while let Some(threepid_res) = users_stream.next().await {
@@ -380,7 +380,7 @@ async fn migrate_external_ids(
380380
span.pb_set_style(&ProgressStyle::default_bar());
381381
span.pb_set_length(count_hint as u64);
382382

383-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
383+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links);
384384
let mut extids_stream = pin!(synapse.read_user_external_ids());
385385

386386
while let Some(extid_res) = extids_stream.next().await {
@@ -463,7 +463,7 @@ async fn migrate_devices(
463463
span.pb_set_length(count_hint as u64);
464464

465465
let mut devices_stream = pin!(synapse.read_devices());
466-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
466+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
467467

468468
while let Some(device_res) = devices_stream.next().await {
469469
span.pb_inc(1);
@@ -565,8 +565,9 @@ async fn migrate_unrefreshable_access_tokens(
565565
span.pb_set_length(count_hint as u64);
566566

567567
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
568-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
569-
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
568+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
569+
let mut deviceless_session_write_buffer =
570+
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
570571

571572
while let Some(token_res) = token_stream.next().await {
572573
span.pb_inc(1);
@@ -683,9 +684,10 @@ async fn migrate_refreshable_token_pairs(
683684
span.pb_set_length(count_hint as u64);
684685

685686
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
686-
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
687+
let mut access_token_write_buffer =
688+
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
687689
let mut refresh_token_write_buffer =
688-
MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
690+
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
689691

690692
while let Some(token_res) = token_stream.next().await {
691693
span.pb_inc(1);

0 commit comments

Comments
 (0)