11use native_tls:: TlsConnector ;
22use postgres_native_tls:: MakeTlsConnector ;
3- use std:: collections:: { BTreeMap , BTreeSet } ;
4- use std:: sync:: Arc ;
5- use tokio_postgres:: Client ;
3+ use std:: {
4+ collections:: { BTreeMap , BTreeSet } ,
5+ error:: Error ,
6+ sync:: Arc ,
7+ } ;
8+ use tokio_postgres:: {
9+ types:: { FromSql , Type } ,
10+ Client ,
11+ } ;
612
13+ use crate :: sql_entities:: View ;
714use crate :: {
815 sql_entities:: {
916 ColumnConstraints , ForeignKey , SqlERData , SqlERDataLoader , SqlEnums , Table , TableColumn ,
@@ -12,7 +19,7 @@ use crate::{
1219} ;
1320
1421static GET_TABLES_LIST_QUERY : & str = r#"
15- SELECT trim(both '"' from table_name) as table_name
22+ SELECT trim(both '"' from table_name) as table_name, table_type
1623FROM information_schema.tables
1724WHERE table_schema = $1
1825ORDER BY table_name;
@@ -188,7 +195,6 @@ impl PostgreSqlERDLoader {
188195 . query ( GET_COLUMNS_BASIC_INFO_QUERY , & [ & table_names] )
189196 . await ?;
190197 for row in rows {
191- // I don't know how to get rid this
192198 let col_num: i16 = row. get ( "col_num" ) ;
193199 let col_name: & str = row. get ( "col_name" ) ;
194200 let not_null: bool = row. get ( "not_null" ) ;
@@ -345,6 +351,27 @@ impl PostgreSqlERDLoader {
345351 }
346352}
347353
354+ #[ derive( Debug , PartialEq ) ]
355+ enum TableType {
356+ BaseTable ,
357+ View ,
358+ }
359+
360+ impl < ' a > FromSql < ' a > for TableType {
361+ fn from_sql ( _ty : & Type , raw : & [ u8 ] ) -> Result < Self , Box < dyn Error + Sync + Send > > {
362+ let s = std:: str:: from_utf8 ( raw) ?;
363+ match s {
364+ "BASE TABLE" => Ok ( TableType :: BaseTable ) ,
365+ "VIEW" => Ok ( TableType :: View ) ,
366+ other => Err ( format ! ( "Unknown table type: {}" , other) . into ( ) ) ,
367+ }
368+ }
369+
370+ fn accepts ( ty : & Type ) -> bool {
371+ * ty == Type :: TEXT || * ty == Type :: VARCHAR
372+ }
373+ }
374+
348375#[ async_trait:: async_trait]
349376impl SqlERDataLoader for PostgreSqlERDLoader {
350377 async fn load_erd_data ( & mut self ) -> Result < SqlERData , crate :: SqlantError > {
@@ -362,14 +389,51 @@ impl SqlERDataLoader for PostgreSqlERDLoader {
362389 . client
363390 . query ( GET_TABLES_LIST_QUERY , & [ & self . schema_name ] )
364391 . await ?;
365- let table_names: Vec < String > = res. iter ( ) . map ( |row| row. get ( "table_name" ) ) . collect ( ) ;
366- let ( tables, enums) = self . load_tables ( table_names) . await ?;
367- let foreign_keys = self . get_fks ( & tables) ?;
392+
393+ // Collect table names and types as a vector of tuples
394+ let table_names_with_types: Vec < ( String , TableType ) > = res
395+ . iter ( )
396+ . map ( |row| ( row. get ( "table_name" ) , row. get ( "table_type" ) ) )
397+ . collect ( ) ;
398+
399+ // Extract just the table names for loading
400+ let table_names: Vec < String > = table_names_with_types
401+ . iter ( )
402+ . map ( |( name, _) | name. clone ( ) )
403+ . collect ( ) ;
404+
405+ let ( tables_and_views, enums) = self . load_tables ( table_names) . await ?;
406+ let foreign_keys = self . get_fks ( & tables_and_views) ?;
407+
408+ let mut views: Vec < Arc < View > > = vec ! [ ] ;
409+ let mut tables: Vec < Arc < Table > > = vec ! [ ] ;
410+
411+ for entity in tables_and_views. into_iter ( ) {
412+ let ( _, r#type) = table_names_with_types
413+ . iter ( )
414+ . find ( |t| t. 0 == entity. name )
415+ . unwrap ( ) ;
416+ match r#type {
417+ TableType :: BaseTable => tables. push ( entity) ,
418+ TableType :: View => {
419+ let Table { name, columns, .. } = Arc :: try_unwrap ( entity) . unwrap ( ) ;
420+ views. push (
421+ View {
422+ materizlied : false ,
423+ name,
424+ columns,
425+ }
426+ . into ( ) ,
427+ ) ;
428+ }
429+ }
430+ }
368431
369432 Ok ( SqlERData {
370433 tables,
371434 foreign_keys,
372435 enums,
436+ views,
373437 } )
374438 }
375439}
0 commit comments