Skip to content

Commit f64c902

Browse files
committed
Replace panic on unfinished buffers with an error
This doesn't get clobbered by the progress bar
1 parent 1f8d491 commit f64c902

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

crates/syn2mas/src/mas_writer/mod.rs

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
//!
88
//! This module is responsible for writing new records to MAS' database.
99
10-
use std::{fmt::Display, net::IpAddr};
10+
use std::{
11+
fmt::Display,
12+
net::IpAddr,
13+
sync::{
14+
atomic::{AtomicU32, Ordering},
15+
Arc,
16+
},
17+
};
1118

1219
use chrono::{DateTime, Utc};
1320
use futures_util::{future::BoxFuture, FutureExt, TryStreamExt};
@@ -44,6 +51,9 @@ pub enum Error {
4451
#[error("inconsistent database: {0}")]
4552
Inconsistent(String),
4653

54+
#[error("bug in syn2mas: write buffers not finished")]
55+
WriteBuffersNotFinished,
56+
4757
#[error("{0}")]
4858
Multiple(MultipleErrors),
4959
}
@@ -185,12 +195,52 @@ impl WriterConnectionPool {
185195
}
186196
}
187197

198+
/// Small utility to make sure `finish()` is called on all write buffers
199+
/// before committing to the database.
200+
#[derive(Default)]
201+
struct FinishChecker {
202+
counter: Arc<AtomicU32>,
203+
}
204+
205+
struct FinishCheckerHandle {
206+
counter: Arc<AtomicU32>,
207+
}
208+
209+
impl FinishChecker {
210+
/// Acquire a new handle, for a task that should declare when it has
211+
/// finished.
212+
pub fn handle(&self) -> FinishCheckerHandle {
213+
self.counter.fetch_add(1, Ordering::SeqCst);
214+
FinishCheckerHandle {
215+
counter: Arc::clone(&self.counter),
216+
}
217+
}
218+
219+
/// Check that all handles have been declared as finished.
220+
pub fn check_all_finished(self) -> Result<(), Error> {
221+
if self.counter.load(Ordering::SeqCst) == 0 {
222+
Ok(())
223+
} else {
224+
Err(Error::WriteBuffersNotFinished)
225+
}
226+
}
227+
}
228+
229+
impl FinishCheckerHandle {
230+
/// Declare that the task this handle represents has been finished.
231+
pub fn declare_finished(self) {
232+
self.counter.fetch_sub(1, Ordering::SeqCst);
233+
}
234+
}
235+
188236
pub struct MasWriter<'c> {
189237
conn: LockedMasDatabase<'c>,
190238
writer_pool: WriterConnectionPool,
191239

192240
indices_to_restore: Vec<IndexDescription>,
193241
constraints_to_restore: Vec<ConstraintDescription>,
242+
243+
write_buffer_finish_checker: FinishChecker,
194244
}
195245

196246
pub struct MasNewUser {
@@ -449,6 +499,7 @@ impl<'conn> MasWriter<'conn> {
449499
writer_pool: WriterConnectionPool::new(writer_connections),
450500
indices_to_restore,
451501
constraints_to_restore,
502+
write_buffer_finish_checker: FinishChecker::default(),
452503
})
453504
}
454505

@@ -515,6 +566,8 @@ impl<'conn> MasWriter<'conn> {
515566
/// - If the database connection experiences an error.
516567
#[tracing::instrument(skip_all)]
517568
pub async fn finish(mut self) -> Result<(), Error> {
569+
self.write_buffer_finish_checker.check_all_finished()?;
570+
518571
// Commit all writer transactions to the database.
519572
self.writer_pool
520573
.finish()
@@ -1027,28 +1080,24 @@ type WriteBufferFlusher<'conn, T> =
10271080

10281081
/// A buffer for writing rows to the MAS database.
10291082
/// Generic over the type of rows.
1030-
///
1031-
/// # Panics
1032-
///
1033-
/// Panics if dropped before `finish()` has been called.
10341083
pub struct MasWriteBuffer<'conn, T> {
10351084
rows: Vec<T>,
10361085
flusher: WriteBufferFlusher<'conn, T>,
1037-
finished: bool,
1086+
finish_checker_handle: FinishCheckerHandle,
10381087
}
10391088

10401089
impl<'conn, T> MasWriteBuffer<'conn, T> {
1041-
pub fn new(flusher: WriteBufferFlusher<'conn, T>) -> Self {
1090+
pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<'conn, T>) -> Self {
10421091
MasWriteBuffer {
10431092
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
10441093
flusher,
1045-
finished: false,
1094+
finish_checker_handle: writer.write_buffer_finish_checker.handle(),
10461095
}
10471096
}
10481097

10491098
pub async fn finish(mut self, writer: &mut MasWriter<'conn>) -> Result<(), Error> {
1050-
self.finished = true;
10511099
self.flush(writer).await?;
1100+
self.finish_checker_handle.declare_finished();
10521101
Ok(())
10531102
}
10541103

@@ -1071,12 +1120,6 @@ impl<'conn, T> MasWriteBuffer<'conn, T> {
10711120
}
10721121
}
10731122

1074-
impl<T> Drop for MasWriteBuffer<'_, T> {
1075-
fn drop(&mut self) {
1076-
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
1077-
}
1078-
}
1079-
10801123
#[cfg(test)]
10811124
mod test {
10821125
use std::collections::{BTreeMap, BTreeSet};

crates/syn2mas/src/migration.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ async fn migrate_users(
249249
span.pb_set_style(&ProgressStyle::default_bar());
250250
span.pb_set_length(user_count_hint as u64);
251251

252-
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
253-
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
252+
let mut user_buffer = MasWriteBuffer::new(mas, MasWriter::write_users);
253+
let mut password_buffer = MasWriteBuffer::new(mas, MasWriter::write_passwords);
254254
let mut users_stream = pin!(synapse.read_users());
255255
// TODO is 1:1 capacity enough for a hashmap?
256256
let mut user_localparts_to_uuid = HashMap::with_capacity(user_count_hint);
@@ -309,8 +309,8 @@ async fn migrate_threepids(
309309
span.pb_set_style(&ProgressStyle::default_bar());
310310
span.pb_set_length(count_hint);
311311

312-
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
313-
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
312+
let mut email_buffer = MasWriteBuffer::new(mas, MasWriter::write_email_threepids);
313+
let mut unsupported_buffer = MasWriteBuffer::new(mas, MasWriter::write_unsupported_threepids);
314314
let mut users_stream = pin!(synapse.read_threepids());
315315

316316
while let Some(threepid_res) = users_stream.next().await {
@@ -400,7 +400,7 @@ async fn migrate_external_ids(
400400
span.pb_set_style(&ProgressStyle::default_bar());
401401
span.pb_set_length(count_hint);
402402

403-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
403+
let mut write_buffer = MasWriteBuffer::new(mas, MasWriter::write_upstream_oauth_links);
404404
let mut extids_stream = pin!(synapse.read_user_external_ids());
405405

406406
while let Some(extid_res) = extids_stream.next().await {
@@ -486,7 +486,7 @@ async fn migrate_devices(
486486
span.pb_set_length(count_hint);
487487

488488
let mut devices_stream = pin!(synapse.read_devices());
489-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
489+
let mut write_buffer = MasWriteBuffer::new(mas, MasWriter::write_compat_sessions);
490490

491491
while let Some(device_res) = devices_stream.next().await {
492492
span.pb_inc(1);
@@ -569,8 +569,9 @@ async fn migrate_unrefreshable_access_tokens(
569569
span.pb_set_length(count_hint);
570570

571571
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
572-
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
573-
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
572+
let mut write_buffer = MasWriteBuffer::new(mas, MasWriter::write_compat_access_tokens);
573+
let mut deviceless_session_write_buffer =
574+
MasWriteBuffer::new(mas, MasWriter::write_compat_sessions);
574575

575576
while let Some(token_res) = token_stream.next().await {
576577
span.pb_inc(1);
@@ -685,9 +686,10 @@ async fn migrate_refreshable_token_pairs(
685686
span.pb_set_length(count_hint);
686687

687688
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
688-
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
689+
let mut access_token_write_buffer =
690+
MasWriteBuffer::new(mas, MasWriter::write_compat_access_tokens);
689691
let mut refresh_token_write_buffer =
690-
MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
692+
MasWriteBuffer::new(mas, MasWriter::write_compat_refresh_tokens);
691693

692694
while let Some(token_res) = token_stream.next().await {
693695
span.pb_inc(1);

0 commit comments

Comments
 (0)