@@ -667,5 +667,176 @@ impl<'writer, 'conn> MasUserWriteBuffer<'writer, 'conn> {
667
667
668
668
#[ cfg( test) ]
669
669
mod test {
670
- // TODO test me
670
+ use std:: collections:: { BTreeMap , BTreeSet } ;
671
+
672
+ use chrono:: DateTime ;
673
+ use futures_util:: TryStreamExt ;
674
+
675
+ use serde:: Serialize ;
676
+ use sqlx:: { Column , PgConnection , PgPool , Row } ;
677
+ use uuid:: Uuid ;
678
+
679
+ use crate :: {
680
+ mas_writer:: { MasNewUser , MasNewUserPassword } ,
681
+ LockedMasDatabase , MasWriter ,
682
+ } ;
683
+
684
+ /// A snapshot of a whole database
685
+ #[ derive( Default , Serialize ) ]
686
+ #[ serde( transparent) ]
687
+ struct DatabaseSnapshot {
688
+ tables : BTreeMap < String , TableSnapshot > ,
689
+ }
690
+
691
+ #[ derive( Serialize ) ]
692
+ #[ serde( transparent) ]
693
+ struct TableSnapshot {
694
+ rows : BTreeSet < RowSnapshot > ,
695
+ }
696
+
697
+ #[ derive( PartialEq , Eq , PartialOrd , Ord , Serialize ) ]
698
+ #[ serde( transparent) ]
699
+ struct RowSnapshot {
700
+ columns_to_values : BTreeMap < String , Option < String > > ,
701
+ }
702
+
703
+ const SKIPPED_TABLES : & [ & str ] = & [ "_sqlx_migrations" ] ;
704
+
705
+ /// Produces a serialisable snapshot of a database, usable for snapshot testing
706
+ ///
707
+ /// For brevity, empty tables, as well as [`SKIPPED_TABLES`], will not be included in the snapshot.
708
+ async fn snapshot_database ( conn : & mut PgConnection ) -> DatabaseSnapshot {
709
+ let mut out = DatabaseSnapshot :: default ( ) ;
710
+ let table_names: Vec < String > = sqlx:: query_scalar (
711
+ "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema();" ,
712
+ )
713
+ . fetch_all ( & mut * conn)
714
+ . await
715
+ . unwrap ( ) ;
716
+
717
+ for table_name in table_names {
718
+ if SKIPPED_TABLES . contains ( & table_name. as_str ( ) ) {
719
+ continue ;
720
+ }
721
+
722
+ let column_names: Vec < String > = sqlx:: query_scalar (
723
+ "SELECT column_name FROM information_schema.columns WHERE table_name = $1 AND table_schema = current_schema();"
724
+ ) . bind ( & table_name) . fetch_all ( & mut * conn) . await . expect ( "failed to get column names for table for snapshotting" ) ;
725
+
726
+ let column_name_list = column_names
727
+ . iter ( )
728
+ // stringify all the values for simplicity
729
+ . map ( |column_name| format ! ( "{column_name}::TEXT AS \" {column_name}\" " ) )
730
+ . collect :: < Vec < _ > > ( )
731
+ . join ( ", " ) ;
732
+
733
+ let table_rows = sqlx:: query ( & format ! ( "SELECT {column_name_list} FROM {table_name};" ) )
734
+ . fetch ( & mut * conn)
735
+ . map_ok ( |row| {
736
+ let mut columns_to_values = BTreeMap :: new ( ) ;
737
+ for ( idx, column) in row. columns ( ) . iter ( ) . enumerate ( ) {
738
+ columns_to_values. insert ( column. name ( ) . to_owned ( ) , row. get ( idx) ) ;
739
+ }
740
+ RowSnapshot { columns_to_values }
741
+ } )
742
+ . try_collect :: < BTreeSet < RowSnapshot > > ( )
743
+ . await
744
+ . expect ( "failed to fetch rows from table for snapshotting" ) ;
745
+
746
+ if !table_rows. is_empty ( ) {
747
+ out. tables
748
+ . insert ( table_name, TableSnapshot { rows : table_rows } ) ;
749
+ }
750
+ }
751
+
752
+ out
753
+ }
754
+
755
+ /// Make a snapshot assertion against the database.
756
+ macro_rules! assert_db_snapshot {
757
+ ( $db: expr) => {
758
+ let db_snapshot = snapshot_database( $db) . await ;
759
+ :: insta:: assert_yaml_snapshot!( db_snapshot) ;
760
+ } ;
761
+ }
762
+
763
+ /// Runs some code with a `MasWriter`.
764
+ ///
765
+ /// The callback is responsible for `finish`ing the `MasWriter`.
766
+ async fn make_mas_writer < ' conn > (
767
+ pool : & PgPool ,
768
+ main_conn : & ' conn mut PgConnection ,
769
+ ) -> MasWriter < ' conn > {
770
+ let mut writer_conns = Vec :: new ( ) ;
771
+ for _ in 0 ..2 {
772
+ writer_conns. push (
773
+ pool. acquire ( )
774
+ . await
775
+ . expect ( "failed to acquire MasWriter writer connection" )
776
+ . detach ( ) ,
777
+ ) ;
778
+ }
779
+ let locked_main_conn = LockedMasDatabase :: try_new ( main_conn)
780
+ . await
781
+ . expect ( "failed to lock MAS database" )
782
+ . expect_left ( "MAS database is already locked" ) ;
783
+ MasWriter :: new ( locked_main_conn, writer_conns)
784
+ . await
785
+ . expect ( "failed to construct MasWriter" )
786
+ }
787
+
788
+ /// Tests writing a single user, without a password.
789
+ #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
790
+ async fn test_write_user ( pool : PgPool ) {
791
+ let mut conn = pool. acquire ( ) . await . unwrap ( ) ;
792
+ let mut writer = make_mas_writer ( & pool, & mut conn) . await ;
793
+
794
+ writer
795
+ . write_users ( vec ! [ MasNewUser {
796
+ user_id: Uuid :: from_u128( 1u128 ) ,
797
+ username: "alice" . to_owned( ) ,
798
+ created_at: DateTime :: default ( ) ,
799
+ locked_at: None ,
800
+ can_request_admin: false ,
801
+ } ] )
802
+ . await
803
+ . expect ( "failed to write user" ) ;
804
+
805
+ writer. finish ( ) . await . expect ( "failed to finish MasWriter" ) ;
806
+
807
+ assert_db_snapshot ! ( & mut conn) ;
808
+ }
809
+
810
+ /// Tests writing a single user, with a password.
811
+ #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
812
+ async fn test_write_user_with_password ( pool : PgPool ) {
813
+ const USER_ID : Uuid = Uuid :: from_u128 ( 1u128 ) ;
814
+
815
+ let mut conn = pool. acquire ( ) . await . unwrap ( ) ;
816
+ let mut writer = make_mas_writer ( & pool, & mut conn) . await ;
817
+
818
+ writer
819
+ . write_users ( vec ! [ MasNewUser {
820
+ user_id: USER_ID ,
821
+ username: "alice" . to_owned( ) ,
822
+ created_at: DateTime :: default ( ) ,
823
+ locked_at: None ,
824
+ can_request_admin: false ,
825
+ } ] )
826
+ . await
827
+ . expect ( "failed to write user" ) ;
828
+ writer
829
+ . write_passwords ( vec ! [ MasNewUserPassword {
830
+ user_password_id: Uuid :: from_u128( 42u128 ) ,
831
+ user_id: USER_ID ,
832
+ hashed_password: "$bcrypt$aaaaaaaaaaa" . to_owned( ) ,
833
+ created_at: DateTime :: default ( ) ,
834
+ } ] )
835
+ . await
836
+ . expect ( "failed to write password" ) ;
837
+
838
+ writer. finish ( ) . await . expect ( "failed to finish MasWriter" ) ;
839
+
840
+ assert_db_snapshot ! ( & mut conn) ;
841
+ }
671
842
}
0 commit comments