10
10
use std:: fmt:: Display ;
11
11
12
12
use chrono:: { DateTime , Utc } ;
13
- use futures_util:: { future:: BoxFuture , TryStreamExt } ;
13
+ use futures_util:: { future:: BoxFuture , FutureExt , TryStreamExt } ;
14
14
use sqlx:: { query, query_as, Executor , PgConnection } ;
15
15
use thiserror:: Error ;
16
16
use thiserror_ext:: { Construct , ContextInto } ;
@@ -222,6 +222,14 @@ pub struct MasNewUnsupportedThreepid {
222
222
pub created_at : DateTime < Utc > ,
223
223
}
224
224
225
+ pub struct MasNewUpstreamOauthLink {
226
+ pub link_id : Uuid ,
227
+ pub user_id : Uuid ,
228
+ pub upstream_provider_id : Uuid ,
229
+ pub subject : String ,
230
+ pub created_at : DateTime < Utc > ,
231
+ }
232
+
225
233
/// The 'version' of the password hashing scheme used for passwords when they
226
234
/// are migrated from Synapse to MAS.
227
235
/// This is version 1, as in the previous syn2mas script.
@@ -234,6 +242,7 @@ pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
234
242
"user_passwords" ,
235
243
"user_emails" ,
236
244
"user_unsupported_third_party_ids" ,
245
+ "upstream_oauth_links" ,
237
246
] ;
238
247
239
248
/// Detect whether a syn2mas migration has started on the given database.
@@ -700,8 +709,6 @@ impl<'conn> MasWriter<'conn> {
700
709
created_ats. push ( created_at) ;
701
710
}
702
711
703
- // `confirmed_at` is going to get removed in a future MAS release,
704
- // so just populate with `created_at`
705
712
sqlx:: query!(
706
713
r#"
707
714
INSERT INTO syn2mas__user_unsupported_third_party_ids
@@ -718,6 +725,55 @@ impl<'conn> MasWriter<'conn> {
718
725
} )
719
726
} ) . await
720
727
}
728
+
729
+ #[ tracing:: instrument( skip_all, level = Level :: DEBUG ) ]
730
+ pub fn write_upstream_oauth_links (
731
+ & mut self ,
732
+ links : Vec < MasNewUpstreamOauthLink > ,
733
+ ) -> BoxFuture < ' _ , Result < ( ) , Error > > {
734
+ if links. is_empty ( ) {
735
+ return async { Ok ( ( ) ) } . boxed ( ) ;
736
+ }
737
+ self . writer_pool . spawn_with_connection ( move |conn| {
738
+ Box :: pin ( async move {
739
+ let mut link_ids: Vec < Uuid > = Vec :: with_capacity ( links. len ( ) ) ;
740
+ let mut user_ids: Vec < Uuid > = Vec :: with_capacity ( links. len ( ) ) ;
741
+ let mut upstream_provider_ids: Vec < Uuid > = Vec :: with_capacity ( links. len ( ) ) ;
742
+ let mut subjects: Vec < String > = Vec :: with_capacity ( links. len ( ) ) ;
743
+ let mut created_ats: Vec < DateTime < Utc > > = Vec :: with_capacity ( links. len ( ) ) ;
744
+
745
+ for MasNewUpstreamOauthLink {
746
+ link_id,
747
+ user_id,
748
+ upstream_provider_id,
749
+ subject,
750
+ created_at,
751
+ } in links
752
+ {
753
+ link_ids. push ( link_id) ;
754
+ user_ids. push ( user_id) ;
755
+ upstream_provider_ids. push ( upstream_provider_id) ;
756
+ subjects. push ( subject) ;
757
+ created_ats. push ( created_at) ;
758
+ }
759
+
760
+ sqlx:: query!(
761
+ r#"
762
+ INSERT INTO syn2mas__upstream_oauth_links
763
+ (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
764
+ SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
765
+ "# ,
766
+ & link_ids[ ..] ,
767
+ & user_ids[ ..] ,
768
+ & upstream_provider_ids[ ..] ,
769
+ & subjects[ ..] ,
770
+ & created_ats[ ..] ,
771
+ ) . execute ( & mut * conn) . await . into_database ( "writing unsupported threepids to MAS" ) ?;
772
+
773
+ Ok ( ( ) )
774
+ } )
775
+ } ) . boxed ( )
776
+ }
721
777
}
722
778
723
779
// How many entries to buffer at once, before writing a batch of rows to the
@@ -727,6 +783,7 @@ impl<'conn> MasWriter<'conn> {
727
783
// stream to two tables at once...)
728
784
const WRITE_BUFFER_BATCH_SIZE : usize = 4096 ;
729
785
786
+ // TODO replace with just `MasWriteBuffer`
730
787
pub struct MasUserWriteBuffer < ' writer , ' conn > {
731
788
users : Vec < MasNewUser > ,
732
789
passwords : Vec < MasNewUserPassword > ,
@@ -786,6 +843,7 @@ impl<'writer, 'conn> MasUserWriteBuffer<'writer, 'conn> {
786
843
}
787
844
}
788
845
846
+ // TODO replace with just `MasWriteBuffer`
789
847
pub struct MasThreepidWriteBuffer < ' writer , ' conn > {
790
848
email : Vec < MasNewEmailThreepid > ,
791
849
unsupported : Vec < MasNewUnsupportedThreepid > ,
@@ -843,6 +901,60 @@ impl<'writer, 'conn> MasThreepidWriteBuffer<'writer, 'conn> {
843
901
}
844
902
}
845
903
904
+ /// A function that can accept and flush buffers from a `MasWriteBuffer`.
905
+ /// Intended uses are the methods on `MasWriter` such as `write_users`.
906
+ type WriteBufferFlusher < ' conn , T > =
907
+ for <' a > fn ( & ' a mut MasWriter < ' conn > , Vec < T > ) -> BoxFuture < ' a , Result < ( ) , Error > > ;
908
+
909
+ /// A buffer for writing rows to the MAS database.
910
+ /// Generic over the type of rows.
911
+ ///
912
+ /// # Panics
913
+ ///
914
+ /// Panics if dropped before `finish()` has been called.
915
+ pub struct MasWriteBuffer < ' conn , T > {
916
+ rows : Vec < T > ,
917
+ flusher : WriteBufferFlusher < ' conn , T > ,
918
+ finished : bool ,
919
+ }
920
+
921
+ impl < ' conn , T > MasWriteBuffer < ' conn , T > {
922
+ pub fn new ( flusher : WriteBufferFlusher < ' conn , T > ) -> Self {
923
+ MasWriteBuffer {
924
+ rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
925
+ flusher,
926
+ finished : false ,
927
+ }
928
+ }
929
+
930
+ pub async fn finish ( mut self , writer : & mut MasWriter < ' conn > ) -> Result < ( ) , Error > {
931
+ self . finished = true ;
932
+ self . flush ( writer) . await ?;
933
+ Ok ( ( ) )
934
+ }
935
+
936
+ pub async fn flush ( & mut self , writer : & mut MasWriter < ' conn > ) -> Result < ( ) , Error > {
937
+ let rows = std:: mem:: take ( & mut self . rows ) ;
938
+ self . rows . reserve_exact ( WRITE_BUFFER_BATCH_SIZE ) ;
939
+ ( self . flusher ) ( writer, rows) . await ?;
940
+ Ok ( ( ) )
941
+ }
942
+
943
+ pub async fn write ( & mut self , writer : & mut MasWriter < ' conn > , row : T ) -> Result < ( ) , Error > {
944
+ self . rows . push ( row) ;
945
+ if self . rows . len ( ) >= WRITE_BUFFER_BATCH_SIZE {
946
+ self . flush ( writer) . await ?;
947
+ }
948
+ Ok ( ( ) )
949
+ }
950
+ }
951
+
952
+ impl < T > Drop for MasWriteBuffer < ' _ , T > {
953
+ fn drop ( & mut self ) {
954
+ assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
955
+ }
956
+ }
957
+
846
958
#[ cfg( test) ]
847
959
mod test {
848
960
use std:: collections:: { BTreeMap , BTreeSet } ;
@@ -855,7 +967,8 @@ mod test {
855
967
856
968
use crate :: {
857
969
mas_writer:: {
858
- MasNewEmailThreepid , MasNewUnsupportedThreepid , MasNewUser , MasNewUserPassword ,
970
+ MasNewEmailThreepid , MasNewUnsupportedThreepid , MasNewUpstreamOauthLink , MasNewUser ,
971
+ MasNewUserPassword ,
859
972
} ,
860
973
LockedMasDatabase , MasWriter ,
861
974
} ;
@@ -1085,4 +1198,39 @@ mod test {
1085
1198
1086
1199
assert_db_snapshot ! ( & mut conn) ;
1087
1200
}
1201
+
1202
+ /// Tests writing a single user, with a link to an upstream provider.
1203
+ /// There needs to be an upstream provider in the database already — in the
1204
+ /// real migration, this is done by running a provider sync first.
1205
+ #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" , fixtures( "upstream_provider" ) ) ]
1206
+ async fn test_write_user_with_upstream_provider_link ( pool : PgPool ) {
1207
+ let mut conn = pool. acquire ( ) . await . unwrap ( ) ;
1208
+ let mut writer = make_mas_writer ( & pool, & mut conn) . await ;
1209
+
1210
+ writer
1211
+ . write_users ( vec ! [ MasNewUser {
1212
+ user_id: Uuid :: from_u128( 1u128 ) ,
1213
+ username: "alice" . to_owned( ) ,
1214
+ created_at: DateTime :: default ( ) ,
1215
+ locked_at: None ,
1216
+ can_request_admin: false ,
1217
+ } ] )
1218
+ . await
1219
+ . expect ( "failed to write user" ) ;
1220
+
1221
+ writer
1222
+ . write_upstream_oauth_links ( vec ! [ MasNewUpstreamOauthLink {
1223
+ user_id: Uuid :: from_u128( 1u128 ) ,
1224
+ link_id: Uuid :: from_u128( 3u128 ) ,
1225
+ upstream_provider_id: Uuid :: from_u128( 4u128 ) ,
1226
+ subject: "12345.67890" . to_owned( ) ,
1227
+ created_at: DateTime :: default ( ) ,
1228
+ } ] )
1229
+ . await
1230
+ . expect ( "failed to write link" ) ;
1231
+
1232
+ writer. finish ( ) . await . expect ( "failed to finish MasWriter" ) ;
1233
+
1234
+ assert_db_snapshot ! ( & mut conn) ;
1235
+ }
1088
1236
}
0 commit comments