@@ -20,11 +20,12 @@ use std::{pin::pin, time::Instant};
2020
2121use chrono:: { DateTime , Utc } ;
2222use compact_str:: CompactString ;
23- use futures_util:: StreamExt as _;
23+ use futures_util:: { SinkExt , StreamExt as _, TryStreamExt as _ } ;
2424use mas_storage:: Clock ;
2525use rand:: { RngCore , SeedableRng } ;
2626use thiserror:: Error ;
2727use thiserror_ext:: ContextInto ;
28+ use tokio_util:: sync:: PollSender ;
2829use tracing:: { info, Level , Span } ;
2930use tracing_indicatif:: span_ext:: IndicatifSpanExt ;
3031use ulid:: Ulid ;
@@ -61,6 +62,8 @@ pub enum Error {
6162 source : ExtractLocalpartError ,
6263 user : FullUserId ,
6364 } ,
65+ #[ error( "channel closed" ) ]
66+ ChannelClosed ,
6467 #[ error( "user {user} was not found for migration but a row in {table} was found for them" ) ]
6568 MissingUserFromDependentTable { table : String , user : FullUserId } ,
6669 #[ error( "missing a mapping for the auth provider with ID {synapse_id:?} (used by {user} and maybe other users)" ) ]
@@ -158,25 +161,27 @@ pub async fn migrate(
158161
159162 span. pb_set_message ( "migrating user rows" ) ;
160163 span. pb_inc ( 1 ) ;
161- let ( mut state, mut mas) = migrate_users ( & mut synapse, mas, counts. users , state, rng) . await ?;
164+ let ( state, mut mas) = migrate_users ( & mut synapse, mas, counts. users , state, rng) . await ?;
162165
163166 span. pb_set_message ( "migrating threepids" ) ;
164167 span. pb_inc ( 1 ) ;
165168 migrate_threepids ( & mut synapse, & mut mas, counts. threepids , rng, & state) . await ?;
166169 span. pb_set_message ( "migrating user external IDs" ) ;
167170 span. pb_inc ( 1 ) ;
168171 migrate_external_ids ( & mut synapse, & mut mas, counts. external_ids , rng, & state) . await ?;
172+
169173 span. pb_set_message ( "migrating access tokens" ) ;
170174 span. pb_inc ( 1 ) ;
171- migrate_unrefreshable_access_tokens (
175+ let ( mut state , mut mas ) = migrate_unrefreshable_access_tokens (
172176 & mut synapse,
173- & mut mas,
177+ mas,
174178 counts. access_tokens ,
175179 clock,
176180 rng,
177- & mut state,
181+ state,
178182 )
179183 . await ?;
184+
180185 span. pb_set_message ( "migrating refresh tokens" ) ;
181186 span. pb_inc ( 1 ) ;
182187 migrate_refreshable_token_pairs (
@@ -188,6 +193,7 @@ pub async fn migrate(
188193 & mut state,
189194 )
190195 . await ?;
196+
191197 span. pb_set_message ( "migrating devices" ) ;
192198 span. pb_inc ( 1 ) ;
193199 migrate_devices ( & mut synapse, & mut mas, counts. devices , rng, & mut state) . await ?;
@@ -221,7 +227,6 @@ async fn migrate_users(
221227) -> Result < ( MigrationState , MasWriter ) , Error > {
222228 let start = Instant :: now ( ) ;
223229
224- let mut users_stream = pin ! ( synapse. read_users( ) . with_progress_bar( count_hint, 10_000 ) ) ;
225230 let ( tx, mut rx) = tokio:: sync:: mpsc:: channel ( 1024 * 1024 ) ;
226231
227232 let mut rng = rand_chacha:: ChaCha8Rng :: from_rng ( rng) . expect ( "failed to seed rng" ) ;
@@ -268,10 +273,12 @@ async fn migrate_users(
268273 Ok ( ( state, mas) )
269274 } ) ;
270275
271- while let Some ( user_res) = users_stream. next ( ) . await {
272- let user = user_res. into_synapse ( "reading user" ) ?;
273- tx. send ( user) . await . expect ( "failed to send in channel" ) ;
274- }
276+ synapse
277+ . read_users ( )
278+ . with_progress_bar ( count_hint, 10_000 )
279+ . map_err ( |e| e. into_synapse ( "reading users" ) )
280+ . forward ( PollSender :: new ( tx) . sink_map_err ( |_| Error :: ChannelClosed ) )
281+ . await ?;
275282
276283 let ( state, mas) = task. await . expect ( "task panicked" ) ?;
277284
@@ -569,124 +576,136 @@ async fn migrate_devices(
569576/// Migrates unrefreshable access tokens (those without an associated refresh
570577/// token). Some of these may be deviceless.
571578#[ tracing:: instrument( skip_all, fields( indicatif. pb_show) , level = Level :: INFO ) ]
572- #[ allow( clippy:: too_many_arguments) ]
579+ #[ allow( clippy:: too_many_arguments, clippy :: type_complexity ) ]
573580async fn migrate_unrefreshable_access_tokens (
574581 synapse : & mut SynapseReader < ' _ > ,
575- mas : & mut MasWriter ,
582+ mut mas : MasWriter ,
576583 count_hint : usize ,
577584 clock : & dyn Clock ,
578585 rng : & mut impl RngCore ,
579- state : & mut MigrationState ,
580- ) -> Result < ( ) , Error > {
586+ mut state : MigrationState ,
587+ ) -> Result < ( MigrationState , MasWriter ) , Error > {
581588 let start = Instant :: now ( ) ;
582589
583- let mut token_stream = pin ! ( synapse
584- . read_unrefreshable_access_tokens( )
585- . with_progress_bar( count_hint, 10_000 ) ) ;
586- let mut write_buffer = MasWriteBuffer :: new ( mas, MasWriter :: write_compat_access_tokens) ;
587- let mut deviceless_session_write_buffer =
588- MasWriteBuffer :: new ( mas, MasWriter :: write_compat_sessions) ;
589-
590- while let Some ( token_res) = token_stream. next ( ) . await {
591- let SynapseAccessToken {
592- user_id : synapse_user_id,
593- device_id,
594- token,
595- valid_until_ms,
596- last_validated,
597- } = token_res. into_synapse ( "reading Synapse access token" ) ?;
590+ let ( tx, mut rx) = tokio:: sync:: mpsc:: channel ( 1024 * 1024 ) ;
598591
599- let username = synapse_user_id
600- . extract_localpart ( & state. server_name )
601- . into_extract_localpart ( synapse_user_id. clone ( ) ) ?
602- . to_owned ( ) ;
603- let Some ( user_infos) = state. users . get ( username. as_str ( ) ) . copied ( ) else {
604- if true || is_likely_appservice ( & username) {
605- // HACK can we do anything better
592+ let now = clock. now ( ) ;
593+ let mut rng = rand_chacha:: ChaChaRng :: from_rng ( rng) . expect ( "failed to seed rng" ) ;
594+ let task = tokio:: spawn ( async move {
595+ let mut write_buffer = MasWriteBuffer :: new ( & mas, MasWriter :: write_compat_access_tokens) ;
596+ let mut deviceless_session_write_buffer =
597+ MasWriteBuffer :: new ( & mas, MasWriter :: write_compat_sessions) ;
598+
599+ while let Some ( token) = rx. recv ( ) . await {
600+ let SynapseAccessToken {
601+ user_id : synapse_user_id,
602+ device_id,
603+ token,
604+ valid_until_ms,
605+ last_validated,
606+ } = token;
607+ let username = synapse_user_id
608+ . extract_localpart ( & state. server_name )
609+ . into_extract_localpart ( synapse_user_id. clone ( ) ) ?
610+ . to_owned ( ) ;
611+ let Some ( user_infos) = state. users . get ( username. as_str ( ) ) . copied ( ) else {
612+ if true || is_likely_appservice ( & username) {
613+ // HACK can we do anything better
614+ continue ;
615+ }
616+ return Err ( Error :: MissingUserFromDependentTable {
617+ table : "access_tokens" . to_owned ( ) ,
618+ user : synapse_user_id,
619+ } ) ;
620+ } ;
621+
622+ if user_infos. flags . is_deactivated ( ) || user_infos. flags . is_guest ( ) {
606623 continue ;
607624 }
608- return Err ( Error :: MissingUserFromDependentTable {
609- table : "access_tokens" . to_owned ( ) ,
610- user : synapse_user_id,
611- } ) ;
612- } ;
613625
614- if user_infos. flags . is_deactivated ( ) || user_infos. flags . is_guest ( ) {
615- continue ;
616- }
626+ // It's not always accurate, but last_validated is *often* the creation time of
627+ // the device If we don't have one, then use the current time as a
628+ // fallback.
629+ let created_at = last_validated. map_or_else ( || now, DateTime :: from) ;
630+
631+ let session_id = if let Some ( device_id) = device_id {
632+ // Use the existing device_id if this is the second token for a device
633+ * state
634+ . devices_to_compat_sessions
635+ . entry ( ( user_infos. mas_user_id , CompactString :: new ( & device_id) ) )
636+ . or_insert_with ( || {
637+ Uuid :: from ( Ulid :: from_datetime_with_source ( created_at. into ( ) , & mut rng) )
638+ } )
639+ } else {
640+ // If this is a deviceless access token, create a deviceless compat session
641+ // for it (since otherwise we won't create one whilst migrating devices)
642+ let deviceless_session_id =
643+ Uuid :: from ( Ulid :: from_datetime_with_source ( created_at. into ( ) , & mut rng) ) ;
644+
645+ deviceless_session_write_buffer
646+ . write (
647+ & mut mas,
648+ MasNewCompatSession {
649+ session_id : deviceless_session_id,
650+ user_id : user_infos. mas_user_id ,
651+ device_id : None ,
652+ human_name : None ,
653+ created_at,
654+ is_synapse_admin : false ,
655+ last_active_at : None ,
656+ last_active_ip : None ,
657+ user_agent : None ,
658+ } ,
659+ )
660+ . await
661+ . into_mas ( "failed to write deviceless compat sessions" ) ?;
617662
618- // It's not always accurate, but last_validated is *often* the creation time of
619- // the device If we don't have one, then use the current time as a
620- // fallback.
621- let created_at = last_validated. map_or_else ( || clock. now ( ) , DateTime :: from) ;
663+ deviceless_session_id
664+ } ;
622665
623- let session_id = if let Some ( device_id) = device_id {
624- // Use the existing device_id if this is the second token for a device
625- * state
626- . devices_to_compat_sessions
627- . entry ( ( user_infos. mas_user_id , CompactString :: new ( & device_id) ) )
628- . or_insert_with ( || {
629- Uuid :: from ( Ulid :: from_datetime_with_source ( created_at. into ( ) , rng) )
630- } )
631- } else {
632- // If this is a deviceless access token, create a deviceless compat session
633- // for it (since otherwise we won't create one whilst migrating devices)
634- let deviceless_session_id =
635- Uuid :: from ( Ulid :: from_datetime_with_source ( created_at. into ( ) , rng) ) ;
666+ let token_id = Uuid :: from ( Ulid :: from_datetime_with_source ( created_at. into ( ) , & mut rng) ) ;
636667
637- deviceless_session_write_buffer
668+ write_buffer
638669 . write (
639- mas,
640- MasNewCompatSession {
641- session_id : deviceless_session_id,
642- user_id : user_infos. mas_user_id ,
643- device_id : None ,
644- human_name : None ,
670+ & mut mas,
671+ MasNewCompatAccessToken {
672+ token_id,
673+ session_id,
674+ access_token : token,
645675 created_at,
646- is_synapse_admin : false ,
647- last_active_at : None ,
648- last_active_ip : None ,
649- user_agent : None ,
676+ expires_at : valid_until_ms. map ( DateTime :: from) ,
650677 } ,
651678 )
652679 . await
653- . into_mas ( "failed to write deviceless compat sessions" ) ?;
654-
655- deviceless_session_id
656- } ;
657-
658- let token_id = Uuid :: from ( Ulid :: from_datetime_with_source ( created_at. into ( ) , rng) ) ;
659-
680+ . into_mas ( "writing compat access tokens" ) ?;
681+ }
660682 write_buffer
661- . write (
662- mas,
663- MasNewCompatAccessToken {
664- token_id,
665- session_id,
666- access_token : token,
667- created_at,
668- expires_at : valid_until_ms. map ( DateTime :: from) ,
669- } ,
670- )
683+ . finish ( & mut mas)
671684 . await
672685 . into_mas ( "writing compat access tokens" ) ?;
673- }
686+ deviceless_session_write_buffer
687+ . finish ( & mut mas)
688+ . await
689+ . into_mas ( "writing deviceless compat sessions" ) ?;
674690
675- write_buffer
676- . finish ( mas)
677- . await
678- . into_mas ( "writing compat access tokens" ) ?;
679- deviceless_session_write_buffer
680- . finish ( mas)
681- . await
682- . into_mas ( "writing deviceless compat sessions" ) ?;
691+ Ok ( ( state, mas) )
692+ } ) ;
693+
694+ synapse
695+ . read_unrefreshable_access_tokens ( )
696+ . with_progress_bar ( count_hint, 10_000 )
697+ . map_err ( |e| e. into_synapse ( "reading tokens" ) )
698+ . forward ( PollSender :: new ( tx) . sink_map_err ( |_| Error :: ChannelClosed ) )
699+ . await ?;
700+
701+ let ( state, mas) = task. await . expect ( "task crashed" ) ?;
683702
684703 info ! (
685704 "non-refreshable access tokens migrated in {:.1}s" ,
686705 Instant :: now( ) . duration_since( start) . as_secs_f64( )
687706 ) ;
688707
689- Ok ( ( ) )
708+ Ok ( ( state , mas ) )
690709}
691710
692711/// Migrates (access token, refresh token) pairs.
0 commit comments