@@ -241,6 +241,46 @@ impl Config {
241241 }
242242 }
243243
244+ let databases = self
245+ . databases
246+ . iter ( )
247+ . map ( |database| database. name . clone ( ) )
248+ . collect :: < HashSet < _ > > ( ) ;
249+
250+ // Automatically configure system catalogs
251+ // as omnisharded.
252+ if self . general . system_catalogs_omnisharded {
253+ for database in databases {
254+ let entry = tables. entry ( database) . or_insert_with ( Vec :: new) ;
255+
256+ for table in [
257+ "pg_class" ,
258+ "pg_attribute" ,
259+ "pg_attrdef" ,
260+ "pg_index" ,
261+ "pg_constraint" ,
262+ "pg_namespace" ,
263+ "pg_database" ,
264+ "pg_tablespace" ,
265+ "pg_type" ,
266+ "pg_proc" ,
267+ "pg_operator" ,
268+ "pg_cast" ,
269+ "pg_enum" ,
270+ "pg_range" ,
271+ "pg_authid" ,
272+ "pg_am" ,
273+ ] {
274+ if entry. iter ( ) . find ( |t| t. name == table) . is_none ( ) {
275+ entry. push ( OmnishardedTable {
276+ name : table. to_string ( ) ,
277+ sticky_routing : true ,
278+ } ) ;
279+ }
280+ }
281+ }
282+ }
283+
244284 tables
245285 }
246286
@@ -772,4 +812,66 @@ tables = ["table_x"]
772812 assert_eq ! ( db2_tables[ 0 ] . name, "table_x" ) ;
773813 assert ! ( !db2_tables[ 0 ] . sticky_routing) ;
774814 }
815+
816+ #[ test]
817+ fn test_omnisharded_tables_system_catalogs ( ) {
818+ // Test with system_catalogs_omnisharded = true
819+ let source_enabled = r#"
820+ [general]
821+ host = "0.0.0.0"
822+ port = 6432
823+ system_catalogs_omnisharded = true
824+
825+ [[databases]]
826+ name = "db1"
827+ host = "127.0.0.1"
828+ port = 5432
829+
830+ [[omnisharded_tables]]
831+ database = "db1"
832+ tables = ["my_table"]
833+ "# ;
834+
835+ let config: Config = toml:: from_str ( source_enabled) . unwrap ( ) ;
836+ let tables = config. omnisharded_tables ( ) ;
837+ let db1_tables = tables. get ( "db1" ) . unwrap ( ) ;
838+
839+ // Should include my_table plus system catalogs
840+ assert ! ( db1_tables. iter( ) . any( |t| t. name == "my_table" ) ) ;
841+ assert ! ( db1_tables. iter( ) . any( |t| t. name == "pg_class" ) ) ;
842+ assert ! ( db1_tables. iter( ) . any( |t| t. name == "pg_attribute" ) ) ;
843+ assert ! ( db1_tables. iter( ) . any( |t| t. name == "pg_namespace" ) ) ;
844+ assert ! ( db1_tables. iter( ) . any( |t| t. name == "pg_type" ) ) ;
845+
846+ // System catalogs should have sticky_routing = true
847+ let pg_class = db1_tables. iter ( ) . find ( |t| t. name == "pg_class" ) . unwrap ( ) ;
848+ assert ! ( pg_class. sticky_routing) ;
849+
850+ // Test with system_catalogs_omnisharded = false
851+ let source_disabled = r#"
852+ [general]
853+ host = "0.0.0.0"
854+ port = 6432
855+ system_catalogs_omnisharded = false
856+
857+ [[databases]]
858+ name = "db1"
859+ host = "127.0.0.1"
860+ port = 5432
861+
862+ [[omnisharded_tables]]
863+ database = "db1"
864+ tables = ["my_table"]
865+ "# ;
866+
867+ let config: Config = toml:: from_str ( source_disabled) . unwrap ( ) ;
868+ let tables = config. omnisharded_tables ( ) ;
869+ let db1_tables = tables. get ( "db1" ) . unwrap ( ) ;
870+
871+ // Should only include my_table, no system catalogs
872+ assert_eq ! ( db1_tables. len( ) , 1 ) ;
873+ assert_eq ! ( db1_tables[ 0 ] . name, "my_table" ) ;
874+ assert ! ( !db1_tables. iter( ) . any( |t| t. name == "pg_class" ) ) ;
875+ assert ! ( !db1_tables. iter( ) . any( |t| t. name == "pg_attribute" ) ) ;
876+ }
775877}
0 commit comments