Skip to content

Commit b310ffc

Browse files
committed
Pipe stream directly in the channel
1 parent ccacd2a commit b310ffc

File tree

3 files changed

+143
-115
lines changed

3 files changed

+143
-115
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/syn2mas/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ serde.workspace = true
1818
thiserror.workspace = true
1919
thiserror-ext.workspace = true
2020
tokio.workspace = true
21+
tokio-util.workspace = true
2122
sqlx.workspace = true
2223
chrono.workspace = true
2324
compact_str.workspace = true

crates/syn2mas/src/migration.rs

Lines changed: 141 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ use std::{pin::pin, time::Instant};
2020

2121
use chrono::{DateTime, Utc};
2222
use compact_str::CompactString;
23-
use futures_util::StreamExt as _;
23+
use futures_util::{SinkExt, StreamExt as _, TryStreamExt as _};
2424
use mas_storage::Clock;
2525
use rand::{RngCore, SeedableRng};
2626
use thiserror::Error;
2727
use thiserror_ext::ContextInto;
28+
use tokio_util::sync::PollSender;
2829
use tracing::{info, Level, Span};
2930
use tracing_indicatif::span_ext::IndicatifSpanExt;
3031
use ulid::Ulid;
@@ -61,6 +62,8 @@ pub enum Error {
6162
source: ExtractLocalpartError,
6263
user: FullUserId,
6364
},
65+
#[error("channel closed")]
66+
ChannelClosed,
6467
#[error("user {user} was not found for migration but a row in {table} was found for them")]
6568
MissingUserFromDependentTable { table: String, user: FullUserId },
6669
#[error("missing a mapping for the auth provider with ID {synapse_id:?} (used by {user} and maybe other users)")]
@@ -121,6 +124,7 @@ pub async fn migrate(
121124
rng,
122125
)
123126
.await?;
127+
let user_localparts_to_uuid = migrated_users.user_localparts_to_uuid;
124128

125129
span.pb_set_message("migrating threepids");
126130
span.pb_inc(1);
@@ -133,7 +137,7 @@ pub async fn migrate(
133137
.expect("More than u64::MAX threepids — unable to handle this many!"),
134138
server_name,
135139
rng,
136-
&migrated_users.user_localparts_to_uuid,
140+
&user_localparts_to_uuid,
137141
)
138142
.await?;
139143

@@ -148,13 +152,13 @@ pub async fn migrate(
148152
.expect("More than u64::MAX external IDs — unable to handle this many!"),
149153
server_name,
150154
rng,
151-
&migrated_users.user_localparts_to_uuid,
155+
&user_localparts_to_uuid,
152156
provider_id_mapping,
153157
)
154158
.await?;
155159

156160
// `(MAS user_id, device_id)` mapped to `compat_session` ULID
157-
let mut devices_to_compat_sessions: HashMap<(Uuid, CompactString), Uuid> =
161+
let devices_to_compat_sessions: HashMap<(Uuid, CompactString), Uuid> =
158162
HashMap::with_capacity_and_hasher(
159163
usize::try_from(counts.devices)
160164
.expect("More than usize::MAX devices — unable to handle this many!")
@@ -166,20 +170,21 @@ pub async fn migrate(
166170

167171
span.pb_set_message("migrating access tokens");
168172
span.pb_inc(1);
169-
migrate_unrefreshable_access_tokens(
170-
&mut synapse,
171-
&mut mas,
172-
counts
173-
.access_tokens
174-
.try_into()
175-
.expect("More than u64::MAX access tokens — unable to handle this many!"),
176-
server_name,
177-
clock,
178-
rng,
179-
&migrated_users.user_localparts_to_uuid,
180-
&mut devices_to_compat_sessions,
181-
)
182-
.await?;
173+
let (mut mas, user_localparts_to_uuid, mut devices_to_compat_sessions) =
174+
migrate_unrefreshable_access_tokens(
175+
&mut synapse,
176+
mas,
177+
counts
178+
.access_tokens
179+
.try_into()
180+
.expect("More than u64::MAX access tokens — unable to handle this many!"),
181+
server_name,
182+
clock,
183+
rng,
184+
user_localparts_to_uuid,
185+
devices_to_compat_sessions,
186+
)
187+
.await?;
183188

184189
span.pb_set_message("migrating refresh tokens");
185190
span.pb_inc(1);
@@ -193,7 +198,7 @@ pub async fn migrate(
193198
server_name,
194199
clock,
195200
rng,
196-
&migrated_users.user_localparts_to_uuid,
201+
&user_localparts_to_uuid,
197202
&mut devices_to_compat_sessions,
198203
)
199204
.await?;
@@ -209,7 +214,7 @@ pub async fn migrate(
209214
.expect("More than u64::MAX devices — unable to handle this many!"),
210215
server_name,
211216
rng,
212-
&migrated_users.user_localparts_to_uuid,
217+
&user_localparts_to_uuid,
213218
&mut devices_to_compat_sessions,
214219
&migrated_users.synapse_admins,
215220
)
@@ -244,9 +249,6 @@ async fn migrate_users(
244249
) -> Result<(UsersMigrated, MasWriter), Error> {
245250
let start = Instant::now();
246251

247-
let mut users_stream = pin!(synapse
248-
.read_users()
249-
.with_progress_bar(user_count_hint as u64, 10_000));
250252
let (tx, mut rx) = tokio::sync::mpsc::channel(1024 * 1024);
251253

252254
let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).expect("failed to seed rng");
@@ -299,10 +301,12 @@ async fn migrate_users(
299301
Ok((synapse_admins, user_localparts_to_uuid, mas))
300302
});
301303

302-
while let Some(user_res) = users_stream.next().await {
303-
let user = user_res.into_synapse("reading user")?;
304-
tx.send(user).await.expect("failed to send in channel");
305-
}
304+
synapse
305+
.read_users()
306+
.with_progress_bar(user_count_hint as u64, 10_000)
307+
.map_err(|e| e.into_synapse("reading users"))
308+
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
309+
.await?;
306310

307311
let (synapse_admins, user_localparts_to_uuid, mas) = task.await.expect("task panicked")?;
308312

@@ -608,122 +612,144 @@ async fn migrate_devices(
608612
/// Migrates unrefreshable access tokens (those without an associated refresh
609613
/// token). Some of these may be deviceless.
610614
#[tracing::instrument(skip_all, fields(indicatif.pb_show), level = Level::INFO)]
611-
#[allow(clippy::too_many_arguments)]
615+
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
612616
async fn migrate_unrefreshable_access_tokens(
613617
synapse: &mut SynapseReader<'_>,
614-
mas: &mut MasWriter,
618+
mut mas: MasWriter,
615619
count_hint: u64,
616620
server_name: &str,
617621
clock: &dyn Clock,
618622
rng: &mut impl RngCore,
619-
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
620-
devices: &mut HashMap<(Uuid, CompactString), Uuid>,
621-
) -> Result<(), Error> {
623+
user_localparts_to_uuid: HashMap<CompactString, Uuid>,
624+
mut devices: HashMap<(Uuid, CompactString), Uuid>,
625+
) -> Result<
626+
(
627+
MasWriter,
628+
HashMap<CompactString, Uuid>,
629+
HashMap<(Uuid, CompactString), Uuid>,
630+
),
631+
Error,
632+
> {
622633
let start = Instant::now();
623634

624-
let mut token_stream = pin!(synapse
625-
.read_unrefreshable_access_tokens()
626-
.with_progress_bar(count_hint, 10_000));
627-
let mut write_buffer = MasWriteBuffer::new(mas, MasWriter::write_compat_access_tokens);
628-
let mut deviceless_session_write_buffer =
629-
MasWriteBuffer::new(mas, MasWriter::write_compat_sessions);
630-
631-
while let Some(token_res) = token_stream.next().await {
632-
let SynapseAccessToken {
633-
user_id: synapse_user_id,
634-
device_id,
635-
token,
636-
valid_until_ms,
637-
last_validated,
638-
} = token_res.into_synapse("reading Synapse access token")?;
635+
let (tx, mut rx) = tokio::sync::mpsc::channel(1024 * 1024);
639636

640-
let username = synapse_user_id
641-
.extract_localpart(server_name)
642-
.into_extract_localpart(synapse_user_id.clone())?
643-
.to_owned();
644-
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
645-
if true || is_likely_appservice(&username) {
646-
// HACK can we do anything better
647-
continue;
648-
}
649-
return Err(Error::MissingUserFromDependentTable {
650-
table: "access_tokens".to_owned(),
651-
user: synapse_user_id,
652-
});
653-
};
637+
let now = clock.now();
638+
let server_name = server_name.to_owned();
639+
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
640+
let task = tokio::spawn(async move {
641+
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
642+
let mut deviceless_session_write_buffer =
643+
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
644+
645+
while let Some(token) = rx.recv().await {
646+
let SynapseAccessToken {
647+
user_id: synapse_user_id,
648+
device_id,
649+
token,
650+
valid_until_ms,
651+
last_validated,
652+
} = token;
653+
654+
let username = synapse_user_id
655+
.extract_localpart(&server_name)
656+
.into_extract_localpart(synapse_user_id.clone())?
657+
.to_owned();
658+
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
659+
if true || is_likely_appservice(&username) {
660+
// HACK can we do anything better
661+
continue;
662+
}
663+
return Err(Error::MissingUserFromDependentTable {
664+
table: "access_tokens".to_owned(),
665+
user: synapse_user_id,
666+
});
667+
};
668+
669+
// It's not always accurate, but last_validated is *often* the creation time of
670+
// the device If we don't have one, then use the current time as a
671+
// fallback.
672+
let created_at = last_validated.map_or_else(|| now, DateTime::from);
673+
674+
let session_id = if let Some(device_id) = device_id {
675+
// Use the existing device_id if this is the second token for a device
676+
*devices
677+
.entry((user_id, CompactString::new(&device_id)))
678+
.or_insert_with(|| {
679+
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
680+
})
681+
} else {
682+
// If this is a deviceless access token, create a deviceless compat session
683+
// for it (since otherwise we won't create one whilst migrating devices)
684+
let deviceless_session_id =
685+
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
686+
687+
deviceless_session_write_buffer
688+
.write(
689+
&mut mas,
690+
MasNewCompatSession {
691+
session_id: deviceless_session_id,
692+
user_id,
693+
device_id: None,
694+
human_name: None,
695+
created_at,
696+
is_synapse_admin: false,
697+
last_active_at: None,
698+
last_active_ip: None,
699+
user_agent: None,
700+
},
701+
)
702+
.await
703+
.into_mas("failed to write deviceless compat sessions")?;
654704

655-
// It's not always accurate, but last_validated is *often* the creation time of
656-
// the device If we don't have one, then use the current time as a
657-
// fallback.
658-
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
705+
deviceless_session_id
706+
};
659707

660-
let session_id = if let Some(device_id) = device_id {
661-
// Use the existing device_id if this is the second token for a device
662-
*devices
663-
.entry((user_id, CompactString::new(&device_id)))
664-
.or_insert_with(|| {
665-
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))
666-
})
667-
} else {
668-
// If this is a deviceless access token, create a deviceless compat session
669-
// for it (since otherwise we won't create one whilst migrating devices)
670-
let deviceless_session_id =
671-
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
708+
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
672709

673-
deviceless_session_write_buffer
710+
// TODO skip access tokens for deactivated users
711+
write_buffer
674712
.write(
675-
mas,
676-
MasNewCompatSession {
677-
session_id: deviceless_session_id,
678-
user_id,
679-
device_id: None,
680-
human_name: None,
713+
&mut mas,
714+
MasNewCompatAccessToken {
715+
token_id,
716+
session_id,
717+
access_token: token,
681718
created_at,
682-
is_synapse_admin: false,
683-
last_active_at: None,
684-
last_active_ip: None,
685-
user_agent: None,
719+
expires_at: valid_until_ms.map(DateTime::from),
686720
},
687721
)
688722
.await
689-
.into_mas("failed to write deviceless compat sessions")?;
690-
691-
deviceless_session_id
692-
};
693-
694-
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
723+
.into_mas("writing compat access tokens")?;
724+
}
695725

696-
// TODO skip access tokens for deactivated users
697726
write_buffer
698-
.write(
699-
mas,
700-
MasNewCompatAccessToken {
701-
token_id,
702-
session_id,
703-
access_token: token,
704-
created_at,
705-
expires_at: valid_until_ms.map(DateTime::from),
706-
},
707-
)
727+
.finish(&mut mas)
708728
.await
709729
.into_mas("writing compat access tokens")?;
710-
}
730+
deviceless_session_write_buffer
731+
.finish(&mut mas)
732+
.await
733+
.into_mas("writing deviceless compat sessions")?;
711734

712-
write_buffer
713-
.finish(mas)
714-
.await
715-
.into_mas("writing compat access tokens")?;
716-
deviceless_session_write_buffer
717-
.finish(mas)
718-
.await
719-
.into_mas("writing deviceless compat sessions")?;
735+
Ok((mas, user_localparts_to_uuid, devices))
736+
});
737+
738+
synapse
739+
.read_unrefreshable_access_tokens()
740+
.with_progress_bar(count_hint, 10_000)
741+
.map_err(|e| e.into_synapse("reading tokens"))
742+
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
743+
.await?;
744+
745+
let (mas, user_localparts_to_uuid, devices) = task.await.expect("task crashed")?;
720746

721747
info!(
722748
"non-refreshable access tokens migrated in {:.1}s",
723749
Instant::now().duration_since(start).as_secs_f64()
724750
);
725751

726-
Ok(())
752+
Ok((mas, user_localparts_to_uuid, devices))
727753
}
728754

729755
/// Migrates (access token, refresh token) pairs.

0 commit comments

Comments
 (0)