Skip to content

Commit c9426f8

Browse files
committed
Pipe stream directly in the channel
1 parent 0e7b245 commit c9426f8

File tree

3 files changed

+120
-99
lines changed

3 files changed

+120
-99
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/syn2mas/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ serde.workspace = true
1919
thiserror.workspace = true
2020
thiserror-ext.workspace = true
2121
tokio.workspace = true
22+
tokio-util.workspace = true
2223
sqlx.workspace = true
2324
chrono.workspace = true
2425
compact_str.workspace = true

crates/syn2mas/src/migration.rs

Lines changed: 118 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ use std::{pin::pin, time::Instant};
2020

2121
use chrono::{DateTime, Utc};
2222
use compact_str::CompactString;
23-
use futures_util::StreamExt as _;
23+
use futures_util::{SinkExt, StreamExt as _, TryStreamExt as _};
2424
use mas_storage::Clock;
2525
use rand::{RngCore, SeedableRng};
2626
use thiserror::Error;
2727
use thiserror_ext::ContextInto;
28+
use tokio_util::sync::PollSender;
2829
use tracing::{info, Level, Span};
2930
use tracing_indicatif::span_ext::IndicatifSpanExt;
3031
use ulid::Ulid;
@@ -61,6 +62,8 @@ pub enum Error {
6162
source: ExtractLocalpartError,
6263
user: FullUserId,
6364
},
65+
#[error("channel closed")]
66+
ChannelClosed,
6467
#[error("user {user} was not found for migration but a row in {table} was found for them")]
6568
MissingUserFromDependentTable { table: String, user: FullUserId },
6669
#[error("missing a mapping for the auth provider with ID {synapse_id:?} (used by {user} and maybe other users)")]
@@ -158,25 +161,27 @@ pub async fn migrate(
158161

159162
span.pb_set_message("migrating user rows");
160163
span.pb_inc(1);
161-
let (mut state, mut mas) = migrate_users(&mut synapse, mas, counts.users, state, rng).await?;
164+
let (state, mut mas) = migrate_users(&mut synapse, mas, counts.users, state, rng).await?;
162165

163166
span.pb_set_message("migrating threepids");
164167
span.pb_inc(1);
165168
migrate_threepids(&mut synapse, &mut mas, counts.threepids, rng, &state).await?;
166169
span.pb_set_message("migrating user external IDs");
167170
span.pb_inc(1);
168171
migrate_external_ids(&mut synapse, &mut mas, counts.external_ids, rng, &state).await?;
172+
169173
span.pb_set_message("migrating access tokens");
170174
span.pb_inc(1);
171-
migrate_unrefreshable_access_tokens(
175+
let (mut state, mut mas) = migrate_unrefreshable_access_tokens(
172176
&mut synapse,
173-
&mut mas,
177+
mas,
174178
counts.access_tokens,
175179
clock,
176180
rng,
177-
&mut state,
181+
state,
178182
)
179183
.await?;
184+
180185
span.pb_set_message("migrating refresh tokens");
181186
span.pb_inc(1);
182187
migrate_refreshable_token_pairs(
@@ -188,6 +193,7 @@ pub async fn migrate(
188193
&mut state,
189194
)
190195
.await?;
196+
191197
span.pb_set_message("migrating devices");
192198
span.pb_inc(1);
193199
migrate_devices(&mut synapse, &mut mas, counts.devices, rng, &mut state).await?;
@@ -221,7 +227,6 @@ async fn migrate_users(
221227
) -> Result<(MigrationState, MasWriter), Error> {
222228
let start = Instant::now();
223229

224-
let mut users_stream = pin!(synapse.read_users().with_progress_bar(count_hint, 10_000));
225230
let (tx, mut rx) = tokio::sync::mpsc::channel(1024 * 1024);
226231

227232
let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).expect("failed to seed rng");
@@ -268,10 +273,12 @@ async fn migrate_users(
268273
Ok((state, mas))
269274
});
270275

271-
while let Some(user_res) = users_stream.next().await {
272-
let user = user_res.into_synapse("reading user")?;
273-
tx.send(user).await.expect("failed to send in channel");
274-
}
276+
synapse
277+
.read_users()
278+
.with_progress_bar(count_hint, 10_000)
279+
.map_err(|e| e.into_synapse("reading users"))
280+
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
281+
.await?;
275282

276283
let (state, mas) = task.await.expect("task panicked")?;
277284

@@ -569,124 +576,136 @@ async fn migrate_devices(
569576
/// Migrates unrefreshable access tokens (those without an associated refresh
570577
/// token). Some of these may be deviceless.
571578
#[tracing::instrument(skip_all, fields(indicatif.pb_show), level = Level::INFO)]
572-
#[allow(clippy::too_many_arguments)]
579+
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
573580
async fn migrate_unrefreshable_access_tokens(
574581
synapse: &mut SynapseReader<'_>,
575-
mas: &mut MasWriter,
582+
mut mas: MasWriter,
576583
count_hint: usize,
577584
clock: &dyn Clock,
578585
rng: &mut impl RngCore,
579-
state: &mut MigrationState,
580-
) -> Result<(), Error> {
586+
mut state: MigrationState,
587+
) -> Result<(MigrationState, MasWriter), Error> {
581588
let start = Instant::now();
582589

583-
let mut token_stream = pin!(synapse
584-
.read_unrefreshable_access_tokens()
585-
.with_progress_bar(count_hint, 10_000));
586-
let mut write_buffer = MasWriteBuffer::new(mas, MasWriter::write_compat_access_tokens);
587-
let mut deviceless_session_write_buffer =
588-
MasWriteBuffer::new(mas, MasWriter::write_compat_sessions);
589-
590-
while let Some(token_res) = token_stream.next().await {
591-
let SynapseAccessToken {
592-
user_id: synapse_user_id,
593-
device_id,
594-
token,
595-
valid_until_ms,
596-
last_validated,
597-
} = token_res.into_synapse("reading Synapse access token")?;
590+
let (tx, mut rx) = tokio::sync::mpsc::channel(1024 * 1024);
598591

599-
let username = synapse_user_id
600-
.extract_localpart(&state.server_name)
601-
.into_extract_localpart(synapse_user_id.clone())?
602-
.to_owned();
603-
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
604-
if true || is_likely_appservice(&username) {
605-
// HACK can we do anything better
592+
let now = clock.now();
593+
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
594+
let task = tokio::spawn(async move {
595+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
596+
let mut deviceless_session_write_buffer =
597+
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
598+
599+
while let Some(token) = rx.recv().await {
600+
let SynapseAccessToken {
601+
user_id: synapse_user_id,
602+
device_id,
603+
token,
604+
valid_until_ms,
605+
last_validated,
606+
} = token;
607+
let username = synapse_user_id
608+
.extract_localpart(&state.server_name)
609+
.into_extract_localpart(synapse_user_id.clone())?
610+
.to_owned();
611+
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
612+
if true || is_likely_appservice(&username) {
613+
// HACK can we do anything better
614+
continue;
615+
}
616+
return Err(Error::MissingUserFromDependentTable {
617+
table: "access_tokens".to_owned(),
618+
user: synapse_user_id,
619+
});
620+
};
621+
622+
if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() {
606623
continue;
607624
}
608-
return Err(Error::MissingUserFromDependentTable {
609-
table: "access_tokens".to_owned(),
610-
user: synapse_user_id,
611-
});
612-
};
613625

614-
if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() {
615-
continue;
616-
}
626+
// It's not always accurate, but last_validated is *often* the creation time of
627+
// the device If we don't have one, then use the current time as a
628+
// fallback.
629+
let created_at = last_validated.map_or_else(|| now, DateTime::from);
630+
631+
let session_id = if let Some(device_id) = device_id {
632+
// Use the existing device_id if this is the second token for a device
633+
*state
634+
.devices_to_compat_sessions
635+
.entry((user_infos.mas_user_id, CompactString::new(&device_id)))
636+
.or_insert_with(|| {
637+
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
638+
})
639+
} else {
640+
// If this is a deviceless access token, create a deviceless compat session
641+
// for it (since otherwise we won't create one whilst migrating devices)
642+
let deviceless_session_id =
643+
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
644+
645+
deviceless_session_write_buffer
646+
.write(
647+
&mut mas,
648+
MasNewCompatSession {
649+
session_id: deviceless_session_id,
650+
user_id: user_infos.mas_user_id,
651+
device_id: None,
652+
human_name: None,
653+
created_at,
654+
is_synapse_admin: false,
655+
last_active_at: None,
656+
last_active_ip: None,
657+
user_agent: None,
658+
},
659+
)
660+
.await
661+
.into_mas("failed to write deviceless compat sessions")?;
617662

618-
// It's not always accurate, but last_validated is *often* the creation time of
619-
// the device If we don't have one, then use the current time as a
620-
// fallback.
621-
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
663+
deviceless_session_id
664+
};
622665

623-
let session_id = if let Some(device_id) = device_id {
624-
// Use the existing device_id if this is the second token for a device
625-
*state
626-
.devices_to_compat_sessions
627-
.entry((user_infos.mas_user_id, CompactString::new(&device_id)))
628-
.or_insert_with(|| {
629-
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))
630-
})
631-
} else {
632-
// If this is a deviceless access token, create a deviceless compat session
633-
// for it (since otherwise we won't create one whilst migrating devices)
634-
let deviceless_session_id =
635-
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
666+
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
636667

637-
deviceless_session_write_buffer
668+
write_buffer
638669
.write(
639-
mas,
640-
MasNewCompatSession {
641-
session_id: deviceless_session_id,
642-
user_id: user_infos.mas_user_id,
643-
device_id: None,
644-
human_name: None,
670+
&mut mas,
671+
MasNewCompatAccessToken {
672+
token_id,
673+
session_id,
674+
access_token: token,
645675
created_at,
646-
is_synapse_admin: false,
647-
last_active_at: None,
648-
last_active_ip: None,
649-
user_agent: None,
676+
expires_at: valid_until_ms.map(DateTime::from),
650677
},
651678
)
652679
.await
653-
.into_mas("failed to write deviceless compat sessions")?;
654-
655-
deviceless_session_id
656-
};
657-
658-
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
659-
680+
.into_mas("writing compat access tokens")?;
681+
}
660682
write_buffer
661-
.write(
662-
mas,
663-
MasNewCompatAccessToken {
664-
token_id,
665-
session_id,
666-
access_token: token,
667-
created_at,
668-
expires_at: valid_until_ms.map(DateTime::from),
669-
},
670-
)
683+
.finish(&mut mas)
671684
.await
672685
.into_mas("writing compat access tokens")?;
673-
}
686+
deviceless_session_write_buffer
687+
.finish(&mut mas)
688+
.await
689+
.into_mas("writing deviceless compat sessions")?;
674690

675-
write_buffer
676-
.finish(mas)
677-
.await
678-
.into_mas("writing compat access tokens")?;
679-
deviceless_session_write_buffer
680-
.finish(mas)
681-
.await
682-
.into_mas("writing deviceless compat sessions")?;
691+
Ok((state, mas))
692+
});
693+
694+
synapse
695+
.read_unrefreshable_access_tokens()
696+
.with_progress_bar(count_hint, 10_000)
697+
.map_err(|e| e.into_synapse("reading tokens"))
698+
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
699+
.await?;
700+
701+
let (state, mas) = task.await.expect("task crashed")?;
683702

684703
info!(
685704
"non-refreshable access tokens migrated in {:.1}s",
686705
Instant::now().duration_since(start).as_secs_f64()
687706
);
688707

689-
Ok(())
708+
Ok((state, mas))
690709
}
691710

692711
/// Migrates (access token, refresh token) pairs.

0 commit comments

Comments
 (0)