diff --git a/Cargo.lock b/Cargo.lock index ecdec9a6f..e156534f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6108,11 +6108,14 @@ dependencies = [ "mas-storage", "mas-storage-pg", "rand", + "rand_chacha", + "rustc-hash 2.1.1", "serde", "sqlx", "thiserror 2.0.11", "thiserror-ext", "tokio", + "tokio-util", "tracing", "ulid", "uuid", diff --git a/crates/cli/src/commands/syn2mas.rs b/crates/cli/src/commands/syn2mas.rs index 314ceaeb6..b5d7b4b7b 100644 --- a/crates/cli/src/commands/syn2mas.rs +++ b/crates/cli/src/commands/syn2mas.rs @@ -80,6 +80,7 @@ enum Subcommand { const NUM_WRITER_CONNECTIONS: usize = 8; impl Options { + #[tracing::instrument("cli.syn2mas.run", skip_all)] #[allow(clippy::too_many_lines)] pub async fn run(self, figment: &Figment) -> anyhow::Result { warn!( @@ -173,14 +174,14 @@ impl Options { // Display errors and warnings if !check_errors.is_empty() { - eprintln!("===== Errors ====="); + eprintln!("\n\n===== Errors ====="); eprintln!("These issues prevent migrating from Synapse to MAS right now:\n"); for error in &check_errors { eprintln!("• {error}\n"); } } if !check_warnings.is_empty() { - eprintln!("===== Warnings ====="); + eprintln!("\n\n===== Warnings ====="); eprintln!( "These potential issues should be considered before migrating from Synapse to MAS right now:\n" ); @@ -220,6 +221,7 @@ impl Options { // TODO how should we handle warnings at this stage? + // TODO this dry-run flag should be set to false in real circumstances !!! let reader = SynapseReader::new(&mut syn_conn, true).await?; let mut writer_mas_connections = Vec::with_capacity(NUM_WRITER_CONNECTIONS); for _ in 0..NUM_WRITER_CONNECTIONS { @@ -234,6 +236,7 @@ impl Options { // TODO progress reporting let mas_matrix = MatrixConfig::extract(figment)?; + eprintln!("\n\n"); syn2mas::migrate( reader, writer, diff --git a/crates/syn2mas/Cargo.toml b/crates/syn2mas/Cargo.toml index 5b80b1510..fdc90f0c9 100644 --- a/crates/syn2mas/Cargo.toml +++ b/crates/syn2mas/Cargo.toml @@ -18,13 +18,16 @@ serde.workspace = true thiserror.workspace = true thiserror-ext.workspace = true tokio.workspace = true +tokio-util.workspace = true sqlx.workspace = true chrono.workspace = true compact_str.workspace = true tracing.workspace = true futures-util = "0.3.31" +rustc-hash = "2.1.1" rand.workspace = true +rand_chacha = "0.3.1" uuid = "1.15.1" ulid = { workspace = true, features = ["uuid"] } diff --git a/crates/syn2mas/src/lib.rs b/crates/syn2mas/src/lib.rs index a7a4b72ca..d0d1162fb 100644 --- a/crates/syn2mas/src/lib.rs +++ b/crates/syn2mas/src/lib.rs @@ -8,6 +8,9 @@ mod synapse_reader; mod migration; +type RandomState = rustc_hash::FxBuildHasher; +type HashMap = rustc_hash::FxHashMap; + pub use self::{ mas_writer::{MasWriter, checks::mas_pre_migration_checks, locking::LockedMasDatabase}, migration::migrate, diff --git a/crates/syn2mas/src/mas_writer/checks.rs b/crates/syn2mas/src/mas_writer/checks.rs index 64c140bde..d5b51b510 100644 --- a/crates/syn2mas/src/mas_writer/checks.rs +++ b/crates/syn2mas/src/mas_writer/checks.rs @@ -10,6 +10,7 @@ use thiserror::Error; use thiserror_ext::ContextInto; +use tracing::Instrument as _; use super::{MAS_TABLES_AFFECTED_BY_MIGRATION, is_syn2mas_in_progress, locking::LockedMasDatabase}; @@ -46,7 +47,7 @@ pub enum Error { /// - If any MAS tables involved in the migration are not empty. /// - If we can't check whether syn2mas is already in progress on this database /// or not. -#[tracing::instrument(skip_all)] +#[tracing::instrument(name = "syn2mas.mas_pre_migration_checks", skip_all)] pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> { if is_syn2mas_in_progress(mas_connection.as_mut()) .await @@ -60,8 +61,11 @@ pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> // empty database. for &table in MAS_TABLES_AFFECTED_BY_MIGRATION { - let row_present = sqlx::query(&format!("SELECT 1 AS dummy FROM {table} LIMIT 1")) + let query = format!("SELECT 1 AS dummy FROM {table} LIMIT 1"); + let span = tracing::info_span!("db.query", db.query.text = query); + let row_present = sqlx::query(&query) .fetch_optional(mas_connection.as_mut()) + .instrument(span) .await .into_maybe_not_mas(table)? .is_some(); diff --git a/crates/syn2mas/src/mas_writer/constraint_pausing.rs b/crates/syn2mas/src/mas_writer/constraint_pausing.rs index 6a420888f..36783215f 100644 --- a/crates/syn2mas/src/mas_writer/constraint_pausing.rs +++ b/crates/syn2mas/src/mas_writer/constraint_pausing.rs @@ -3,8 +3,10 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::time::Instant; + use sqlx::PgConnection; -use tracing::debug; +use tracing::{debug, info}; use super::{Error, IntoDatabase}; @@ -109,15 +111,20 @@ pub async fn drop_index(conn: &mut PgConnection, index: &IndexDescription) -> Re /// Restores (recreates) a constraint. /// /// The constraint must not exist prior to this call. +#[tracing::instrument(name = "syn2mas.restore_constraint", skip_all, fields(constraint.name = constraint.name))] pub async fn restore_constraint( conn: &mut PgConnection, constraint: &ConstraintDescription, ) -> Result<(), Error> { + let start = Instant::now(); + let ConstraintDescription { name, table_name, definition, } = &constraint; + info!("rebuilding constraint {name}"); + sqlx::query(&format!( "ALTER TABLE {table_name} ADD CONSTRAINT {name} {definition};" )) @@ -127,13 +134,21 @@ pub async fn restore_constraint( format!("failed to recreate constraint {name} on {table_name} with {definition}") })?; + info!( + "constraint {name} rebuilt in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); + Ok(()) } /// Restores (recreates) a index. /// /// The index must not exist prior to this call. +#[tracing::instrument(name = "syn2mas.restore_index", skip_all, fields(index.name = index.name))] pub async fn restore_index(conn: &mut PgConnection, index: &IndexDescription) -> Result<(), Error> { + let start = Instant::now(); + let IndexDescription { name, table_name, @@ -147,5 +162,10 @@ pub async fn restore_index(conn: &mut PgConnection, index: &IndexDescription) -> format!("failed to recreate index {name} on {table_name} with {definition}") })?; + info!( + "index {name} rebuilt in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); + Ok(()) } diff --git a/crates/syn2mas/src/mas_writer/mod.rs b/crates/syn2mas/src/mas_writer/mod.rs index 4acf21d6f..a56e69980 100644 --- a/crates/syn2mas/src/mas_writer/mod.rs +++ b/crates/syn2mas/src/mas_writer/mod.rs @@ -7,7 +7,14 @@ //! //! This module is responsible for writing new records to MAS' database. -use std::{fmt::Display, net::IpAddr}; +use std::{ + fmt::Display, + net::IpAddr, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, +}; use chrono::{DateTime, Utc}; use futures_util::{FutureExt, TryStreamExt, future::BoxFuture}; @@ -15,7 +22,7 @@ use sqlx::{Executor, PgConnection, query, query_as}; use thiserror::Error; use thiserror_ext::{Construct, ContextInto}; use tokio::sync::mpsc::{self, Receiver, Sender}; -use tracing::{Level, error, info, warn}; +use tracing::{Instrument, Level, error, info, warn}; use uuid::{NonNilUuid, Uuid}; use self::{ @@ -44,6 +51,9 @@ pub enum Error { #[error("inconsistent database: {0}")] Inconsistent(String), + #[error("bug in syn2mas: write buffers not finished")] + WriteBuffersNotFinished, + #[error("{0}")] Multiple(MultipleErrors), } @@ -109,18 +119,21 @@ impl WriterConnectionPool { match self.connection_rx.recv().await { Some(Ok(mut connection)) => { let connection_tx = self.connection_tx.clone(); - tokio::task::spawn(async move { - let to_return = match task(&mut connection).await { - Ok(()) => Ok(connection), - Err(error) => { - error!("error in writer: {error}"); - Err(error) - } - }; - // This should always succeed in sending unless we're already shutting - // down for some other reason. - let _: Result<_, _> = connection_tx.send(to_return).await; - }); + tokio::task::spawn( + async move { + let to_return = match task(&mut connection).await { + Ok(()) => Ok(connection), + Err(error) => { + error!("error in writer: {error}"); + Err(error) + } + }; + // This should always succeed in sending unless we're already shutting + // down for some other reason. + let _: Result<_, _> = connection_tx.send(to_return).await; + } + .instrument(tracing::debug_span!("spawn_with_connection")), + ); Ok(()) } @@ -188,12 +201,52 @@ impl WriterConnectionPool { } } +/// Small utility to make sure `finish()` is called on all write buffers +/// before committing to the database. +#[derive(Default)] +struct FinishChecker { + counter: Arc, +} + +struct FinishCheckerHandle { + counter: Arc, +} + +impl FinishChecker { + /// Acquire a new handle, for a task that should declare when it has + /// finished. + pub fn handle(&self) -> FinishCheckerHandle { + self.counter.fetch_add(1, Ordering::SeqCst); + FinishCheckerHandle { + counter: Arc::clone(&self.counter), + } + } + + /// Check that all handles have been declared as finished. + pub fn check_all_finished(self) -> Result<(), Error> { + if self.counter.load(Ordering::SeqCst) == 0 { + Ok(()) + } else { + Err(Error::WriteBuffersNotFinished) + } + } +} + +impl FinishCheckerHandle { + /// Declare that the task this handle represents has been finished. + pub fn declare_finished(self) { + self.counter.fetch_sub(1, Ordering::SeqCst); + } +} + pub struct MasWriter { conn: LockedMasDatabase, writer_pool: WriterConnectionPool, indices_to_restore: Vec, constraints_to_restore: Vec, + + write_buffer_finish_checker: FinishChecker, } pub struct MasNewUser { @@ -336,7 +389,7 @@ impl MasWriter { /// /// - If the database connection experiences an error. #[allow(clippy::missing_panics_doc)] // not real - #[tracing::instrument(skip_all)] + #[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)] pub async fn new( mut conn: LockedMasDatabase, mut writer_connections: Vec, @@ -453,6 +506,7 @@ impl MasWriter { writer_pool: WriterConnectionPool::new(writer_connections), indices_to_restore, constraints_to_restore, + write_buffer_finish_checker: FinishChecker::default(), }) } @@ -520,6 +574,8 @@ impl MasWriter { /// - If the database connection experiences an error. #[tracing::instrument(skip_all)] pub async fn finish(mut self) -> Result { + self.write_buffer_finish_checker.check_all_finished()?; + // Commit all writer transactions to the database. self.writer_pool .finish() @@ -1033,28 +1089,24 @@ type WriteBufferFlusher = /// A buffer for writing rows to the MAS database. /// Generic over the type of rows. -/// -/// # Panics -/// -/// Panics if dropped before `finish()` has been called. pub struct MasWriteBuffer { rows: Vec, flusher: WriteBufferFlusher, - finished: bool, + finish_checker_handle: FinishCheckerHandle, } impl MasWriteBuffer { - pub fn new(flusher: WriteBufferFlusher) -> Self { + pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher) -> Self { MasWriteBuffer { rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE), flusher, - finished: false, + finish_checker_handle: writer.write_buffer_finish_checker.handle(), } } pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> { - self.finished = true; self.flush(writer).await?; + self.finish_checker_handle.declare_finished(); Ok(()) } @@ -1077,12 +1129,6 @@ impl MasWriteBuffer { } } -impl Drop for MasWriteBuffer { - fn drop(&mut self) { - assert!(self.finished, "MasWriteBuffer dropped but not finished!"); - } -} - #[cfg(test)] mod test { use std::collections::{BTreeMap, BTreeSet}; diff --git a/crates/syn2mas/src/migration.rs b/crates/syn2mas/src/migration.rs index 5135a5d80..cd08def9e 100644 --- a/crates/syn2mas/src/migration.rs +++ b/crates/syn2mas/src/migration.rs @@ -11,21 +11,22 @@ //! This module does not implement any of the safety checks that should be run //! *before* the migration. -use std::{collections::HashMap, pin::pin}; +use std::{pin::pin, time::Instant}; use chrono::{DateTime, Utc}; use compact_str::CompactString; -use futures_util::StreamExt as _; +use futures_util::{SinkExt, StreamExt as _, TryFutureExt, TryStreamExt as _}; use mas_storage::Clock; -use rand::RngCore; +use rand::{RngCore, SeedableRng}; use thiserror::Error; use thiserror_ext::ContextInto; -use tracing::Level; +use tokio_util::sync::PollSender; +use tracing::{Instrument as _, Level, info}; use ulid::Ulid; use uuid::{NonNilUuid, Uuid}; use crate::{ - SynapseReader, + HashMap, RandomState, SynapseReader, mas_writer::{ self, MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession, MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser, @@ -54,6 +55,15 @@ pub enum Error { source: ExtractLocalpartError, user: FullUserId, }, + #[error("channel closed")] + ChannelClosed, + + #[error("task failed ({context}): {source}")] + Join { + source: tokio::task::JoinError, + context: String, + }, + #[error("user {user} was not found for migration but a row in {table} was found for them")] MissingUserFromDependentTable { table: String, user: FullUserId }, #[error( @@ -114,7 +124,7 @@ struct MigrationState { /// A mapping of Synapse external ID providers to MAS upstream OAuth 2.0 /// provider ID - provider_id_mapping: HashMap, + provider_id_mapping: std::collections::HashMap, } /// Performs a migration from Synapse's database to MAS' database. @@ -136,14 +146,19 @@ pub async fn migrate( server_name: String, clock: &dyn Clock, rng: &mut impl RngCore, - provider_id_mapping: HashMap, + provider_id_mapping: std::collections::HashMap, ) -> Result<(), Error> { let counts = synapse.count_rows().await.into_synapse("counting users")?; let state = MigrationState { server_name, - users: HashMap::with_capacity(counts.users), - devices_to_compat_sessions: HashMap::with_capacity(counts.devices), + // We oversize the hashmaps, as the estimates are innaccurate, and we would like to avoid + // reallocations. + users: HashMap::with_capacity_and_hasher(counts.users * 9 / 8, RandomState::default()), + devices_to_compat_sessions: HashMap::with_capacity_and_hasher( + counts.devices * 9 / 8, + RandomState::default(), + ), provider_id_mapping, }; @@ -175,82 +190,110 @@ async fn migrate_users( mut state: MigrationState, rng: &mut impl RngCore, ) -> Result<(MasWriter, MigrationState), Error> { - let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users); - let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords); - let mut users_stream = pin!(synapse.read_users()); - - while let Some(user_res) = users_stream.next().await { - let user = user_res.into_synapse("reading user")?; - - // Handling an edge case: some AS users may have invalid localparts containing - // extra `:` characters. These users are ignored and a warning is logged. - if user.appservice_id.is_some() - && user - .name - .0 - .strip_suffix(&format!(":{}", state.server_name)) - .is_some_and(|localpart| localpart.contains(':')) - { - tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0); - continue; - } - - let (mas_user, mas_password_opt) = transform_user(&user, &state.server_name, rng)?; + let start = Instant::now(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::(10 * 1024 * 1024); + + let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).expect("failed to seed rng"); + let task = tokio::spawn( + async move { + let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users); + let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords); + + while let Some(user) = rx.recv().await { + // Handling an edge case: some AS users may have invalid localparts containing + // extra `:` characters. These users are ignored and a warning is logged. + if user.appservice_id.is_some() + && user + .name + .0 + .strip_suffix(&format!(":{}", state.server_name)) + .is_some_and(|localpart| localpart.contains(':')) + { + tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0); + continue; + } + + let (mas_user, mas_password_opt) = + transform_user(&user, &state.server_name, &mut rng)?; + + let mut flags = UserFlags::empty(); + if bool::from(user.admin) { + flags |= UserFlags::IS_SYNAPSE_ADMIN; + } + if bool::from(user.deactivated) { + flags |= UserFlags::IS_DEACTIVATED; + } + if bool::from(user.is_guest) { + flags |= UserFlags::IS_GUEST; + } + if user.appservice_id.is_some() { + flags |= UserFlags::IS_APPSERVICE; + + // Special case for appservice users: we don't insert them into the database + // We just record the user's information in the state and continue + state.users.insert( + CompactString::new(&mas_user.username), + UserInfo { + mas_user_id: None, + flags, + }, + ); + continue; + } + + state.users.insert( + CompactString::new(&mas_user.username), + UserInfo { + mas_user_id: Some(mas_user.user_id), + flags, + }, + ); + + user_buffer + .write(&mut mas, mas_user) + .await + .into_mas("writing user")?; + + if let Some(mas_password) = mas_password_opt { + password_buffer + .write(&mut mas, mas_password) + .await + .into_mas("writing password")?; + } + } + + user_buffer + .finish(&mut mas) + .await + .into_mas("writing users")?; + password_buffer + .finish(&mut mas) + .await + .into_mas("writing passwords")?; - let mut flags = UserFlags::empty(); - if bool::from(user.admin) { - flags |= UserFlags::IS_SYNAPSE_ADMIN; - } - if bool::from(user.deactivated) { - flags |= UserFlags::IS_DEACTIVATED; - } - if bool::from(user.is_guest) { - flags |= UserFlags::IS_GUEST; - } - if user.appservice_id.is_some() { - flags |= UserFlags::IS_APPSERVICE; - - // Special case for appservice users: we don't insert them into the database - // We just record the user's information in the state and continue - state.users.insert( - CompactString::new(&mas_user.username), - UserInfo { - mas_user_id: None, - flags, - }, - ); - continue; + Ok((mas, state)) } + .instrument(tracing::info_span!("ingest_task")), + ); - state.users.insert( - CompactString::new(&mas_user.username), - UserInfo { - mas_user_id: Some(mas_user.user_id), - flags, - }, - ); + // In case this has an error, we still want to join the task, so we look at the + // error later + let res = synapse + .read_users() + .map_err(|e| e.into_synapse("reading users")) + .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed)) + .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error)) + .await; - user_buffer - .write(&mut mas, mas_user) - .await - .into_mas("writing user")?; + let (mas, state) = task.await.into_join("user write task")??; - if let Some(mas_password) = mas_password_opt { - password_buffer - .write(&mut mas, mas_password) - .await - .into_mas("writing password")?; - } - } + res?; - user_buffer - .finish(&mut mas) - .await - .into_mas("writing users")?; - password_buffer - .finish(&mut mas) - .await - .into_mas("writing passwords")?; + info!( + "users migrated in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); Ok((mas, state)) } @@ -262,8 +305,10 @@ async fn migrate_threepids( rng: &mut impl RngCore, state: MigrationState, ) -> Result<(MasWriter, MigrationState), Error> { - let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids); - let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids); + let start = Instant::now(); + + let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids); + let mut unsupported_buffer = MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids); let mut users_stream = pin!(synapse.read_threepids()); while let Some(threepid_res) = users_stream.next().await { @@ -331,6 +376,11 @@ async fn migrate_threepids( .await .into_mas("writing unsupported threepids")?; + info!( + "third-party IDs migrated in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); + Ok((mas, state)) } @@ -345,7 +395,9 @@ async fn migrate_external_ids( rng: &mut impl RngCore, state: MigrationState, ) -> Result<(MasWriter, MigrationState), Error> { - let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links); + let start = Instant::now(); + + let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links); let mut extids_stream = pin!(synapse.read_user_external_ids()); while let Some(extid_res) = extids_stream.next().await { @@ -400,7 +452,12 @@ async fn migrate_external_ids( write_buffer .finish(&mut mas) .await - .into_mas("writing threepids")?; + .into_mas("writing upstream links")?; + + info!( + "upstream links (external IDs) migrated in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); Ok((mas, state)) } @@ -420,92 +477,121 @@ async fn migrate_devices( rng: &mut impl RngCore, mut state: MigrationState, ) -> Result<(MasWriter, MigrationState), Error> { - let mut devices_stream = pin!(synapse.read_devices()); - let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions); - - while let Some(device_res) = devices_stream.next().await { - let SynapseDevice { - user_id: synapse_user_id, - device_id, - display_name, - last_seen, - ip, - user_agent, - } = device_res.into_synapse("reading Synapse device")?; - - let username = synapse_user_id - .extract_localpart(&state.server_name) - .into_extract_localpart(synapse_user_id.clone())? - .to_owned(); - let Some(user_infos) = state.users.get(username.as_str()).copied() else { - return Err(Error::MissingUserFromDependentTable { - table: "devices".to_owned(), - user: synapse_user_id, - }); - }; - - let Some(mas_user_id) = user_infos.mas_user_id else { - continue; - }; + let start = Instant::now(); + + let (tx, mut rx) = tokio::sync::mpsc::channel(10 * 1024 * 1024); + + let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); + let task = tokio::spawn( + async move { + let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions); + + while let Some(device) = rx.recv().await { + let SynapseDevice { + user_id: synapse_user_id, + device_id, + display_name, + last_seen, + ip, + user_agent, + } = device; + let username = synapse_user_id + .extract_localpart(&state.server_name) + .into_extract_localpart(synapse_user_id.clone())? + .to_owned(); + let Some(user_infos) = state.users.get(username.as_str()).copied() else { + return Err(Error::MissingUserFromDependentTable { + table: "devices".to_owned(), + user: synapse_user_id, + }); + }; + + let Some(mas_user_id) = user_infos.mas_user_id else { + continue; + }; + + if user_infos.flags.is_deactivated() + || user_infos.flags.is_guest() + || user_infos.flags.is_appservice() + { + continue; + } + + let session_id = *state + .devices_to_compat_sessions + .entry((mas_user_id, CompactString::new(&device_id))) + .or_insert_with(|| + // We don't have a creation time for this device (as it has no access token), + // so use now as a least-evil fallback. + Ulid::with_source(&mut rng).into()); + let created_at = Ulid::from(session_id).datetime().into(); + + // As we're using a real IP type in the MAS database, it is possible + // that we encounter invalid IP addresses in the Synapse database. + // In that case, we should ignore them, but still log a warning. + // One special case: Synapse will record '-' as IP in some cases, we don't want + // to log about those + let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| { + ip.parse() + .map_err(|e| { + tracing::warn!( + error = &e as &dyn std::error::Error, + mxid = %synapse_user_id, + %device_id, + %ip, + "Failed to parse device IP, ignoring" + ); + }) + .ok() + }); + + // TODO skip access tokens for deactivated users + write_buffer + .write( + &mut mas, + MasNewCompatSession { + session_id, + user_id: mas_user_id, + device_id: Some(device_id), + human_name: display_name, + created_at, + is_synapse_admin: user_infos.flags.is_synapse_admin(), + last_active_at: last_seen.map(DateTime::from), + last_active_ip, + user_agent, + }, + ) + .await + .into_mas("writing compat sessions")?; + } + + write_buffer + .finish(&mut mas) + .await + .into_mas("writing compat sessions")?; - if user_infos.flags.is_deactivated() - || user_infos.flags.is_guest() - || user_infos.flags.is_appservice() - { - continue; + Ok((mas, state)) } + .instrument(tracing::info_span!("ingest_task")), + ); - let session_id = *state - .devices_to_compat_sessions - .entry((mas_user_id, CompactString::new(&device_id))) - .or_insert_with(|| - // We don't have a creation time for this device (as it has no access token), - // so use now as a least-evil fallback. - Ulid::with_source(rng).into()); - let created_at = Ulid::from(session_id).datetime().into(); - - // As we're using a real IP type in the MAS database, it is possible - // that we encounter invalid IP addresses in the Synapse database. - // In that case, we should ignore them, but still log a warning. - // One special case: Synapse will record '-' as IP in some cases, we don't want - // to log about those - let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| { - ip.parse() - .map_err(|e| { - tracing::warn!( - error = &e as &dyn std::error::Error, - mxid = %synapse_user_id, - %device_id, - %ip, - "Failed to parse device IP, ignoring" - ); - }) - .ok() - }); + // In case this has an error, we still want to join the task, so we look at the + // error later + let res = synapse + .read_devices() + .map_err(|e| e.into_synapse("reading devices")) + .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed)) + .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error)) + .await; - write_buffer - .write( - &mut mas, - MasNewCompatSession { - session_id, - user_id: mas_user_id, - device_id: Some(device_id), - human_name: display_name, - created_at, - is_synapse_admin: user_infos.flags.is_synapse_admin(), - last_active_at: last_seen.map(DateTime::from), - last_active_ip, - user_agent, - }, - ) - .await - .into_mas("writing compat sessions")?; - } + let (mas, state) = task.await.into_join("device write task")??; - write_buffer - .finish(&mut mas) - .await - .into_mas("writing compat sessions")?; + res?; + + info!( + "devices migrated in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); Ok((mas, state)) } @@ -520,106 +606,136 @@ async fn migrate_unrefreshable_access_tokens( rng: &mut impl RngCore, mut state: MigrationState, ) -> Result<(MasWriter, MigrationState), Error> { - let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens()); - let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens); - let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions); - - while let Some(token_res) = token_stream.next().await { - let SynapseAccessToken { - user_id: synapse_user_id, - device_id, - token, - valid_until_ms, - last_validated, - } = token_res.into_synapse("reading Synapse access token")?; - - let username = synapse_user_id - .extract_localpart(&state.server_name) - .into_extract_localpart(synapse_user_id.clone())? - .to_owned(); - let Some(user_infos) = state.users.get(username.as_str()).copied() else { - return Err(Error::MissingUserFromDependentTable { - table: "access_tokens".to_owned(), - user: synapse_user_id, - }); - }; - - let Some(mas_user_id) = user_infos.mas_user_id else { - continue; - }; - - if user_infos.flags.is_deactivated() - || user_infos.flags.is_guest() - || user_infos.flags.is_appservice() - { - continue; - } - - // It's not always accurate, but last_validated is *often* the creation time of - // the device If we don't have one, then use the current time as a - // fallback. - let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from); - - let session_id = if let Some(device_id) = device_id { - // Use the existing device_id if this is the second token for a device - *state - .devices_to_compat_sessions - .entry((mas_user_id, CompactString::new(&device_id))) - .or_insert_with(|| { - Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)) - }) - } else { - // If this is a deviceless access token, create a deviceless compat session - // for it (since otherwise we won't create one whilst migrating devices) - let deviceless_session_id = - Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); - + let start = Instant::now(); + + let (tx, mut rx) = tokio::sync::mpsc::channel(10 * 1024 * 1024); + + let now = clock.now(); + let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); + let task = tokio::spawn( + async move { + let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens); + let mut deviceless_session_write_buffer = + MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions); + + while let Some(token) = rx.recv().await { + let SynapseAccessToken { + user_id: synapse_user_id, + device_id, + token, + valid_until_ms, + last_validated, + } = token; + let username = synapse_user_id + .extract_localpart(&state.server_name) + .into_extract_localpart(synapse_user_id.clone())? + .to_owned(); + let Some(user_infos) = state.users.get(username.as_str()).copied() else { + return Err(Error::MissingUserFromDependentTable { + table: "access_tokens".to_owned(), + user: synapse_user_id, + }); + }; + + let Some(mas_user_id) = user_infos.mas_user_id else { + continue; + }; + + if user_infos.flags.is_deactivated() + || user_infos.flags.is_guest() + || user_infos.flags.is_appservice() + { + continue; + } + + // It's not always accurate, but last_validated is *often* the creation time of + // the device If we don't have one, then use the current time as a + // fallback. + let created_at = last_validated.map_or_else(|| now, DateTime::from); + + let session_id = if let Some(device_id) = device_id { + // Use the existing device_id if this is the second token for a device + *state + .devices_to_compat_sessions + .entry((mas_user_id, CompactString::new(&device_id))) + .or_insert_with(|| { + Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng)) + }) + } else { + // If this is a deviceless access token, create a deviceless compat session + // for it (since otherwise we won't create one whilst migrating devices) + let deviceless_session_id = + Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng)); + + deviceless_session_write_buffer + .write( + &mut mas, + MasNewCompatSession { + session_id: deviceless_session_id, + user_id: mas_user_id, + device_id: None, + human_name: None, + created_at, + is_synapse_admin: false, + last_active_at: None, + last_active_ip: None, + user_agent: None, + }, + ) + .await + .into_mas("failed to write deviceless compat sessions")?; + + deviceless_session_id + }; + + let token_id = + Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng)); + + write_buffer + .write( + &mut mas, + MasNewCompatAccessToken { + token_id, + session_id, + access_token: token, + created_at, + expires_at: valid_until_ms.map(DateTime::from), + }, + ) + .await + .into_mas("writing compat access tokens")?; + } + write_buffer + .finish(&mut mas) + .await + .into_mas("writing compat access tokens")?; deviceless_session_write_buffer - .write( - &mut mas, - MasNewCompatSession { - session_id: deviceless_session_id, - user_id: mas_user_id, - device_id: None, - human_name: None, - created_at, - is_synapse_admin: false, - last_active_at: None, - last_active_ip: None, - user_agent: None, - }, - ) + .finish(&mut mas) .await - .into_mas("failed to write deviceless compat sessions")?; + .into_mas("writing deviceless compat sessions")?; - deviceless_session_id - }; + Ok((mas, state)) + } + .instrument(tracing::info_span!("ingest_task")), + ); - let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); + // In case this has an error, we still want to join the task, so we look at the + // error later + let res = synapse + .read_unrefreshable_access_tokens() + .map_err(|e| e.into_synapse("reading tokens")) + .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed)) + .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error)) + .await; - write_buffer - .write( - &mut mas, - MasNewCompatAccessToken { - token_id, - session_id, - access_token: token, - created_at, - expires_at: valid_until_ms.map(DateTime::from), - }, - ) - .await - .into_mas("writing compat access tokens")?; - } + let (mas, state) = task.await.into_join("token write task")??; - write_buffer - .finish(&mut mas) - .await - .into_mas("writing compat access tokens")?; - deviceless_session_write_buffer - .finish(&mut mas) - .await - .into_mas("writing deviceless compat sessions")?; + res?; + + info!( + "non-refreshable access tokens migrated in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); Ok((mas, state)) } @@ -634,10 +750,13 @@ async fn migrate_refreshable_token_pairs( rng: &mut impl RngCore, mut state: MigrationState, ) -> Result<(MasWriter, MigrationState), Error> { + let start = Instant::now(); + let mut token_stream = pin!(synapse.read_refreshable_token_pairs()); - let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens); + let mut access_token_write_buffer = + MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens); let mut refresh_token_write_buffer = - MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens); + MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens); while let Some(token_res) = token_stream.next().await { let SynapseRefreshableTokenPair { @@ -723,6 +842,11 @@ async fn migrate_refreshable_token_pairs( .await .into_mas("writing compat refresh tokens")?; + info!( + "refreshable token pairs migrated in {:.1}s", + Instant::now().duration_since(start).as_secs_f64() + ); + Ok((mas, state)) } diff --git a/crates/syn2mas/src/synapse_reader/checks.rs b/crates/syn2mas/src/synapse_reader/checks.rs index b2495f327..83f31bbcf 100644 --- a/crates/syn2mas/src/synapse_reader/checks.rs +++ b/crates/syn2mas/src/synapse_reader/checks.rs @@ -48,21 +48,11 @@ pub enum CheckError { )] PasswordSchemeWrongPepper, - #[error( - "Synapse database contains {num_guests} guests which aren't supported by MAS. See https://github.com/element-hq/matrix-authentication-service/issues/1445" - )] - GuestsInDatabase { num_guests: i64 }, - #[error( "Guest support is enabled in the Synapse configuration. Guests aren't supported by MAS, but if you don't have any then you could disable the option. See https://github.com/element-hq/matrix-authentication-service/issues/1445" )] GuestsEnabled, - #[error( - "Synapse database contains {num_non_email_3pids} non-email 3PIDs (probably phone numbers), which are not supported by MAS." - )] - NonEmailThreepidsInDatabase { num_non_email_3pids: i64 }, - #[error( "Synapse config has `enable_3pid_changes` explicitly enabled, which must be disabled or removed." )] @@ -125,6 +115,16 @@ pub enum CheckWarning { "Synapse config has a registration CAPTCHA enabled, but no CAPTCHA has been configured in MAS. You may wish to manually configure this." )] ShouldPortRegistrationCaptcha, + + #[error( + "Synapse database contains {num_guests} guests which will be migrated are not supported by MAS. See https://github.com/element-hq/matrix-authentication-service/issues/1445" + )] + GuestsInDatabase { num_guests: i64 }, + + #[error( + "Synapse database contains {num_non_email_3pids} non-email 3PIDs (probably phone numbers), which will be migrated but are not supported by MAS." + )] + NonEmailThreepidsInDatabase { num_non_email_3pids: i64 }, } /// Check that the Synapse configuration is sane for migration. @@ -140,15 +140,6 @@ pub fn synapse_config_check(synapse_config: &Config) -> (Vec, Vec< warnings.push(CheckWarning::DisableUserConsentAfterMigration); } - // TODO check the settings directly against the MAS settings - for provider in synapse_config.all_oidc_providers().values() { - if let Some(ref issuer) = provider.issuer { - warnings.push(CheckWarning::UpstreamOidcProvider { - issuer: issuer.clone(), - }); - } - } - // TODO provide guidance on migrating these if synapse_config.cas_config.enabled { warnings.push(CheckWarning::ExternalAuthSystem("CAS")); @@ -269,13 +260,13 @@ pub async fn synapse_database_check( } let mut errors = Vec::new(); - let warnings = Vec::new(); + let mut warnings = Vec::new(); let num_guests: i64 = query_scalar("SELECT COUNT(1) FROM users WHERE is_guest <> 0") .fetch_one(&mut *synapse_connection) .await?; if num_guests > 0 { - errors.push(CheckError::GuestsInDatabase { num_guests }); + warnings.push(CheckWarning::GuestsInDatabase { num_guests }); } let num_non_email_3pids: i64 = @@ -283,7 +274,7 @@ pub async fn synapse_database_check( .fetch_one(&mut *synapse_connection) .await?; if num_non_email_3pids > 0 { - errors.push(CheckError::NonEmailThreepidsInDatabase { + warnings.push(CheckWarning::NonEmailThreepidsInDatabase { num_non_email_3pids, }); } diff --git a/crates/syn2mas/src/synapse_reader/mod.rs b/crates/syn2mas/src/synapse_reader/mod.rs index 6646af1b1..54333eb44 100644 --- a/crates/syn2mas/src/synapse_reader/mod.rs +++ b/crates/syn2mas/src/synapse_reader/mod.rs @@ -336,28 +336,31 @@ impl<'conn> SynapseReader<'conn> { /// /// - An underlying database error pub async fn count_rows(&mut self) -> Result { - let users: usize = sqlx::query_scalar::<_, i64>( + // We don't get to filter out application service users by using this estimate, + // which is a shame, but on a large database this is way faster. + // On matrix.org, counting users and devices properly takes around 1m10s, + // which is unnecessary extra downtime during the migration, just to + // show a more accurate progress bar and size a hash map accurately. + let users = sqlx::query_scalar::<_, i64>( " - SELECT COUNT(1) FROM users - WHERE appservice_id IS NULL + SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'users'::regclass; ", ) .fetch_one(&mut *self.txn) .await - .into_database("counting Synapse users")? + .into_database("estimating count of users")? .max(0) .try_into() .unwrap_or(usize::MAX); let devices = sqlx::query_scalar::<_, i64>( " - SELECT COUNT(1) FROM devices - WHERE NOT hidden + SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'devices'::regclass; ", ) .fetch_one(&mut *self.txn) .await - .into_database("counting Synapse devices")? + .into_database("estimating count of devices")? .max(0) .try_into() .unwrap_or(usize::MAX); @@ -427,6 +430,12 @@ impl<'conn> SynapseReader<'conn> { /// Reads unrefreshable access tokens from the Synapse database. /// This does not include access tokens used for puppetting users, as those /// are not supported by MAS. + /// + /// This also excludes access tokens whose referenced device ID does not + /// exist, except for deviceless access tokens. + /// (It's unclear what mechanism led to these, but since Synapse has no + /// foreign key constraints and is not consistently atomic about this, + /// it should be no surprise really) pub fn read_unrefreshable_access_tokens( &mut self, ) -> impl Stream> + '_ { @@ -435,7 +444,15 @@ impl<'conn> SynapseReader<'conn> { SELECT at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated FROM access_tokens at0 + INNER JOIN devices USING (user_id, device_id) WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL + + UNION + + SELECT + at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated + FROM access_tokens at0 + WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL AND at0.device_id IS NULL ", ) .fetch(&mut *self.txn) @@ -459,7 +476,8 @@ impl<'conn> SynapseReader<'conn> { SELECT rt0.user_id, rt0.device_id, at0.token AS access_token, rt0.token AS refresh_token, at0.valid_until_ms, at0.last_validated FROM refresh_tokens rt0 - LEFT JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id + INNER JOIN devices USING (device_id) + INNER JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id LEFT JOIN access_tokens at1 ON at1.refresh_token_id = rt0.next_token_id WHERE NOT at1.used OR at1.used IS NULL ", @@ -552,7 +570,10 @@ mod test { assert_debug_snapshot!(devices); } - #[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "access_token_alice"))] + #[sqlx::test( + migrator = "MIGRATOR", + fixtures("user_alice", "devices_alice", "access_token_alice") + )] async fn test_read_access_token(pool: PgPool) { let mut conn = pool.acquire().await.expect("failed to get connection"); let mut reader = SynapseReader::new(&mut conn, false) @@ -571,7 +592,7 @@ mod test { /// Tests that puppetting access tokens are ignored. #[sqlx::test( migrator = "MIGRATOR", - fixtures("user_alice", "access_token_alice_with_puppet") + fixtures("user_alice", "devices_alice", "access_token_alice_with_puppet") )] async fn test_read_access_token_puppet(pool: PgPool) { let mut conn = pool.acquire().await.expect("failed to get connection"); @@ -590,7 +611,7 @@ mod test { #[sqlx::test( migrator = "MIGRATOR", - fixtures("user_alice", "access_token_alice_with_refresh_token") + fixtures("user_alice", "devices_alice", "access_token_alice_with_refresh_token") )] async fn test_read_access_and_refresh_tokens(pool: PgPool) { let mut conn = pool.acquire().await.expect("failed to get connection"); @@ -619,7 +640,11 @@ mod test { #[sqlx::test( migrator = "MIGRATOR", - fixtures("user_alice", "access_token_alice_with_unused_refresh_token") + fixtures( + "user_alice", + "devices_alice", + "access_token_alice_with_unused_refresh_token" + ) )] async fn test_read_access_and_unused_refresh_tokens(pool: PgPool) { let mut conn = pool.acquire().await.expect("failed to get connection");