1010use std:: fmt:: Display ;
1111
1212use chrono:: { DateTime , Utc } ;
13- use futures_util:: { future:: BoxFuture , TryStreamExt } ;
13+ use futures_util:: { future:: BoxFuture , FutureExt , TryStreamExt } ;
1414use sqlx:: { query, query_as, Executor , PgConnection } ;
1515use thiserror:: Error ;
1616use thiserror_ext:: { Construct , ContextInto } ;
@@ -222,6 +222,14 @@ pub struct MasNewUnsupportedThreepid {
222222 pub created_at : DateTime < Utc > ,
223223}
224224
225+ pub struct MasNewUpstreamOauthLink {
226+ pub link_id : Uuid ,
227+ pub user_id : Uuid ,
228+ pub upstream_provider_id : Uuid ,
229+ pub subject : String ,
230+ pub created_at : DateTime < Utc > ,
231+ }
232+
225233/// The 'version' of the password hashing scheme used for passwords when they
226234/// are migrated from Synapse to MAS.
227235/// This is version 1, as in the previous syn2mas script.
@@ -234,6 +242,7 @@ pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
234242 "user_passwords" ,
235243 "user_emails" ,
236244 "user_unsupported_third_party_ids" ,
245+ "upstream_oauth_links" ,
237246] ;
238247
239248/// Detect whether a syn2mas migration has started on the given database.
@@ -700,8 +709,6 @@ impl<'conn> MasWriter<'conn> {
700709 created_ats. push ( created_at) ;
701710 }
702711
703- // `confirmed_at` is going to get removed in a future MAS release,
704- // so just populate with `created_at`
705712 sqlx:: query!(
706713 r#"
707714 INSERT INTO syn2mas__user_unsupported_third_party_ids
@@ -718,6 +725,55 @@ impl<'conn> MasWriter<'conn> {
718725 } )
719726 } ) . await
720727 }
728+
729+ #[ tracing:: instrument( skip_all, level = Level :: DEBUG ) ]
730+ pub fn write_upstream_oauth_links (
731+ & mut self ,
732+ links : Vec < MasNewUpstreamOauthLink > ,
733+ ) -> BoxFuture < ' _ , Result < ( ) , Error > > {
734+ if links. is_empty ( ) {
735+ return async { Ok ( ( ) ) } . boxed ( ) ;
736+ }
737+ self . writer_pool . spawn_with_connection ( move |conn| {
738+ Box :: pin ( async move {
739+ let mut link_ids: Vec < Uuid > = Vec :: with_capacity ( links. len ( ) ) ;
740+ let mut user_ids: Vec < Uuid > = Vec :: with_capacity ( links. len ( ) ) ;
741+ let mut upstream_provider_ids: Vec < Uuid > = Vec :: with_capacity ( links. len ( ) ) ;
742+ let mut subjects: Vec < String > = Vec :: with_capacity ( links. len ( ) ) ;
743+ let mut created_ats: Vec < DateTime < Utc > > = Vec :: with_capacity ( links. len ( ) ) ;
744+
745+ for MasNewUpstreamOauthLink {
746+ link_id,
747+ user_id,
748+ upstream_provider_id,
749+ subject,
750+ created_at,
751+ } in links
752+ {
753+ link_ids. push ( link_id) ;
754+ user_ids. push ( user_id) ;
755+ upstream_provider_ids. push ( upstream_provider_id) ;
756+ subjects. push ( subject) ;
757+ created_ats. push ( created_at) ;
758+ }
759+
760+ sqlx:: query!(
761+ r#"
762+ INSERT INTO syn2mas__upstream_oauth_links
763+ (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
764+ SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
765+ "# ,
766+ & link_ids[ ..] ,
767+ & user_ids[ ..] ,
768+ & upstream_provider_ids[ ..] ,
769+ & subjects[ ..] ,
770+ & created_ats[ ..] ,
771+ ) . execute ( & mut * conn) . await . into_database ( "writing unsupported threepids to MAS" ) ?;
772+
773+ Ok ( ( ) )
774+ } )
775+ } ) . boxed ( )
776+ }
721777}
722778
723779// How many entries to buffer at once, before writing a batch of rows to the
@@ -727,6 +783,7 @@ impl<'conn> MasWriter<'conn> {
727783// stream to two tables at once...)
728784const WRITE_BUFFER_BATCH_SIZE : usize = 4096 ;
729785
786+ // TODO replace with just `MasWriteBuffer`
730787pub struct MasUserWriteBuffer < ' writer , ' conn > {
731788 users : Vec < MasNewUser > ,
732789 passwords : Vec < MasNewUserPassword > ,
@@ -786,6 +843,7 @@ impl<'writer, 'conn> MasUserWriteBuffer<'writer, 'conn> {
786843 }
787844}
788845
846+ // TODO replace with just `MasWriteBuffer`
789847pub struct MasThreepidWriteBuffer < ' writer , ' conn > {
790848 email : Vec < MasNewEmailThreepid > ,
791849 unsupported : Vec < MasNewUnsupportedThreepid > ,
@@ -843,6 +901,60 @@ impl<'writer, 'conn> MasThreepidWriteBuffer<'writer, 'conn> {
843901 }
844902}
845903
904+ /// A function that can accept and flush buffers from a `MasWriteBuffer`.
905+ /// Intended uses are the methods on `MasWriter` such as `write_users`.
906+ type WriteBufferFlusher < ' conn , T > =
907+ for <' a > fn ( & ' a mut MasWriter < ' conn > , Vec < T > ) -> BoxFuture < ' a , Result < ( ) , Error > > ;
908+
909+ /// A buffer for writing rows to the MAS database.
910+ /// Generic over the type of rows.
911+ ///
912+ /// # Panics
913+ ///
914+ /// Panics if dropped before `finish()` has been called.
915+ pub struct MasWriteBuffer < ' conn , T > {
916+ rows : Vec < T > ,
917+ flusher : WriteBufferFlusher < ' conn , T > ,
918+ finished : bool ,
919+ }
920+
921+ impl < ' conn , T > MasWriteBuffer < ' conn , T > {
922+ pub fn new ( flusher : WriteBufferFlusher < ' conn , T > ) -> Self {
923+ MasWriteBuffer {
924+ rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
925+ flusher,
926+ finished : false ,
927+ }
928+ }
929+
930+ pub async fn finish ( mut self , writer : & mut MasWriter < ' conn > ) -> Result < ( ) , Error > {
931+ self . finished = true ;
932+ self . flush ( writer) . await ?;
933+ Ok ( ( ) )
934+ }
935+
936+ pub async fn flush ( & mut self , writer : & mut MasWriter < ' conn > ) -> Result < ( ) , Error > {
937+ let rows = std:: mem:: take ( & mut self . rows ) ;
938+ self . rows . reserve_exact ( WRITE_BUFFER_BATCH_SIZE ) ;
939+ ( self . flusher ) ( writer, rows) . await ?;
940+ Ok ( ( ) )
941+ }
942+
943+ pub async fn write ( & mut self , writer : & mut MasWriter < ' conn > , row : T ) -> Result < ( ) , Error > {
944+ self . rows . push ( row) ;
945+ if self . rows . len ( ) >= WRITE_BUFFER_BATCH_SIZE {
946+ self . flush ( writer) . await ?;
947+ }
948+ Ok ( ( ) )
949+ }
950+ }
951+
952+ impl < T > Drop for MasWriteBuffer < ' _ , T > {
953+ fn drop ( & mut self ) {
954+ assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
955+ }
956+ }
957+
846958#[ cfg( test) ]
847959mod test {
848960 use std:: collections:: { BTreeMap , BTreeSet } ;
@@ -855,7 +967,8 @@ mod test {
855967
856968 use crate :: {
857969 mas_writer:: {
858- MasNewEmailThreepid , MasNewUnsupportedThreepid , MasNewUser , MasNewUserPassword ,
970+ MasNewEmailThreepid , MasNewUnsupportedThreepid , MasNewUpstreamOauthLink , MasNewUser ,
971+ MasNewUserPassword ,
859972 } ,
860973 LockedMasDatabase , MasWriter ,
861974 } ;
@@ -1085,4 +1198,39 @@ mod test {
10851198
10861199 assert_db_snapshot ! ( & mut conn) ;
10871200 }
1201+
1202+ /// Tests writing a single user, with a link to an upstream provider.
1203+ /// There needs to be an upstream provider in the database already — in the
1204+ /// real migration, this is done by running a provider sync first.
1205+ #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" , fixtures( "upstream_provider" ) ) ]
1206+ async fn test_write_user_with_upstream_provider_link ( pool : PgPool ) {
1207+ let mut conn = pool. acquire ( ) . await . unwrap ( ) ;
1208+ let mut writer = make_mas_writer ( & pool, & mut conn) . await ;
1209+
1210+ writer
1211+ . write_users ( vec ! [ MasNewUser {
1212+ user_id: Uuid :: from_u128( 1u128 ) ,
1213+ username: "alice" . to_owned( ) ,
1214+ created_at: DateTime :: default ( ) ,
1215+ locked_at: None ,
1216+ can_request_admin: false ,
1217+ } ] )
1218+ . await
1219+ . expect ( "failed to write user" ) ;
1220+
1221+ writer
1222+ . write_upstream_oauth_links ( vec ! [ MasNewUpstreamOauthLink {
1223+ user_id: Uuid :: from_u128( 1u128 ) ,
1224+ link_id: Uuid :: from_u128( 3u128 ) ,
1225+ upstream_provider_id: Uuid :: from_u128( 4u128 ) ,
1226+ subject: "12345.67890" . to_owned( ) ,
1227+ created_at: DateTime :: default ( ) ,
1228+ } ] )
1229+ . await
1230+ . expect ( "failed to write link" ) ;
1231+
1232+ writer. finish ( ) . await . expect ( "failed to finish MasWriter" ) ;
1233+
1234+ assert_db_snapshot ! ( & mut conn) ;
1235+ }
10881236}
0 commit comments