diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs
index 2dde9d6..99268d1 100644
--- a/datafusion-postgres/src/pg_catalog.rs
+++ b/datafusion-postgres/src/pg_catalog.rs
@@ -1,3 +1,5 @@
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 
 use async_trait::async_trait;
@@ -16,6 +18,8 @@ use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::streaming::PartitionStream;
 use datafusion::prelude::{create_udf, SessionContext};
+use postgres_types::Oid;
+use tokio::sync::RwLock;
 
 const PG_CATALOG_TABLE_PG_TYPE: &str = "pg_type";
 const PG_CATALOG_TABLE_PG_CLASS: &str = "pg_class";
@@ -208,10 +212,19 @@ impl PgTypesData {
     }
 }
 
+#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
+enum OidCacheKey {
+    Schema(String),
+    /// Table by schema and table name
+    Table(String, String),
+}
+
 // Create custom schema provider for pg_catalog
 #[derive(Debug)]
 pub struct PgCatalogSchemaProvider {
     catalog_list: Arc<dyn CatalogProviderList>,
+    oid_counter: Arc<AtomicU32>,
+    oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
 #[async_trait]
@@ -229,13 +242,21 @@ impl SchemaProvider for PgCatalogSchemaProvider {
             PG_CATALOG_TABLE_PG_TYPE => Ok(Some(self.create_pg_type_table())),
             PG_CATALOG_TABLE_PG_AM => Ok(Some(self.create_pg_am_table())),
             PG_CATALOG_TABLE_PG_CLASS => {
-                let table = Arc::new(PgClassTable::new(self.catalog_list.clone()));
+                let table = Arc::new(PgClassTable::new(
+                    self.catalog_list.clone(),
+                    self.oid_counter.clone(),
+                    self.oid_cache.clone(),
+                ));
                 Ok(Some(Arc::new(
                     StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
                 )))
             }
             PG_CATALOG_TABLE_PG_NAMESPACE => {
-                let table = Arc::new(PgNamespaceTable::new(self.catalog_list.clone()));
+                let table = Arc::new(PgNamespaceTable::new(
+                    self.catalog_list.clone(),
+                    self.oid_counter.clone(),
+                    self.oid_cache.clone(),
+                ));
                 Ok(Some(Arc::new(
                     StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
                 )))
@@ -266,7 +287,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {
 
 impl PgCatalogSchemaProvider {
     pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgCatalogSchemaProvider {
-        Self { catalog_list }
+        Self {
+            catalog_list,
+            oid_counter: Arc::new(AtomicU32::new(0)),
+            oid_cache: Arc::new(RwLock::new(HashMap::new())),
+        }
     }
 
     /// Create a populated pg_type table with standard PostgreSQL data types
@@ -1033,14 +1058,20 @@ impl PgProcData {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct PgClassTable {
     schema: SchemaRef,
     catalog_list: Arc<dyn CatalogProviderList>,
+    oid_counter: Arc<AtomicU32>,
+    oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
 impl PgClassTable {
-    fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgClassTable {
+    fn new(
+        catalog_list: Arc<dyn CatalogProviderList>,
+        oid_counter: Arc<AtomicU32>,
+        oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
+    ) -> PgClassTable {
         // Define the schema for pg_class
         // This matches key columns from PostgreSQL's pg_class
         let schema = Arc::new(Schema::new(vec![
@@ -1079,14 +1110,13 @@ impl PgClassTable {
         Self {
             schema,
             catalog_list,
+            oid_counter,
+            oid_cache,
         }
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(
-        schema: SchemaRef,
-        catalog_list: Arc<dyn CatalogProviderList>,
-    ) -> Result<RecordBatch> {
+    async fn get_data(this: PgClassTable) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut oids = Vec::new();
         let mut relnames = Vec::new();
@@ -1119,24 +1149,37 @@ impl PgClassTable {
         let mut relfrozenxids = Vec::new();
         let mut relminmxids = Vec::new();
 
-        // Start OID counter (this is simplistic and would need to be more robust in practice)
-        let mut next_oid = 10000;
+        let mut oid_cache = this.oid_cache.write().await;
+        // Each time pg_catalog is queried we build a fresh cache and drop the
+        // original one, in case schemas or tables have been dropped.
+        let mut swap_cache = HashMap::new();
 
         // Iterate through all catalogs and schemas
-        for catalog_name in catalog_list.catalog_names() {
-            if let Some(catalog) = catalog_list.catalog(&catalog_name) {
+        for catalog_name in this.catalog_list.catalog_names() {
+            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
                 for schema_name in catalog.schema_names() {
                     if let Some(schema) = catalog.schema(&schema_name) {
-                        let schema_oid = next_oid;
-                        next_oid += 1;
+                        let cache_key = OidCacheKey::Schema(schema_name.clone());
+                        let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                            *oid
+                        } else {
+                            this.oid_counter.fetch_add(1, Ordering::Relaxed)
+                        };
+                        swap_cache.insert(cache_key, schema_oid);
 
                         // Add an entry for the schema itself (as a namespace)
                         // (In a full implementation, this would go in pg_namespace)
 
                         // Now process all tables in this schema
                         for table_name in schema.table_names() {
-                            let table_oid = next_oid;
-                            next_oid += 1;
+                            let cache_key =
+                                OidCacheKey::Table(schema_name.clone(), table_name.clone());
+                            let table_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                                *oid
+                            } else {
+                                this.oid_counter.fetch_add(1, Ordering::Relaxed)
+                            };
+                            swap_cache.insert(cache_key, table_oid);
 
                             if let Some(table) = schema.table(&table_name).await? {
                                 // Determine the correct table type based on the table provider and context
@@ -1147,14 +1190,14 @@ impl PgClassTable {
                                 let column_count = table.schema().fields().len() as i16;
 
                                 // Add table entry
-                                oids.push(table_oid);
+                                oids.push(table_oid as i32);
                                 relnames.push(table_name.clone());
-                                relnamespaces.push(schema_oid);
+                                relnamespaces.push(schema_oid as i32);
                                 reltypes.push(0); // Simplified: we're not tracking data types
                                 reloftypes.push(None);
                                 relowners.push(0); // Simplified: no owner tracking
                                 relams.push(0); // Default access method
-                                relfilenodes.push(table_oid); // Use OID as filenode
+                                relfilenodes.push(table_oid as i32); // Use OID as filenode
                                 reltablespaces.push(0); // Default tablespace
                                 relpages.push(1); // Default page count
                                 reltuples.push(0.0); // No row count stats
@@ -1184,6 +1227,8 @@ impl PgClassTable {
             }
         }
 
+        *oid_cache = swap_cache;
+
         // Create Arrow arrays from the collected data
        let arrays: Vec<ArrayRef> = vec![
             Arc::new(Int32Array::from(oids)),
@@ -1219,7 +1264,7 @@ impl PgClassTable {
         ];
 
         // Create a record batch
-        let batch = RecordBatch::try_new(schema.clone(), arrays)?;
+        let batch = RecordBatch::try_new(this.schema.clone(), arrays)?;
 
         Ok(batch)
     }
@@ -1231,23 +1276,28 @@ impl PartitionStream for PgClassTable {
     }
 
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
-        let catalog_list = self.catalog_list.clone();
-        let schema = Arc::clone(&self.schema);
+        let this = self.clone();
         Box::pin(RecordBatchStreamAdapter::new(
-            schema.clone(),
-            futures::stream::once(async move { Self::get_data(schema, catalog_list).await }),
+            this.schema.clone(),
+            futures::stream::once(async move { PgClassTable::get_data(this).await }),
         ))
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct PgNamespaceTable {
     schema: SchemaRef,
     catalog_list: Arc<dyn CatalogProviderList>,
+    oid_counter: Arc<AtomicU32>,
+    oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
 impl PgNamespaceTable {
-    pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> Self {
+    pub fn new(
+        catalog_list: Arc<dyn CatalogProviderList>,
+        oid_counter: Arc<AtomicU32>,
+        oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
+    ) -> Self {
         // Define the schema for pg_namespace
         // This matches the columns from PostgreSQL's pg_namespace
         let schema = Arc::new(Schema::new(vec![
@@ -1261,14 +1311,13 @@ impl PgNamespaceTable {
         Self {
             schema,
             catalog_list,
+            oid_counter,
+            oid_cache,
         }
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(
-        schema: SchemaRef,
-        catalog_list: Arc<dyn CatalogProviderList>,
-    ) -> Result<RecordBatch> {
+    async fn get_data(this: PgNamespaceTable) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut oids = Vec::new();
         let mut nspnames = Vec::new();
@@ -1276,47 +1325,24 @@ impl PgNamespaceTable {
         let mut nspacls: Vec<Option<String>> = Vec::new();
         let mut options: Vec<Option<String>> = Vec::new();
 
-        // Start OID counter (should be consistent with the values used in pg_class)
-        let mut next_oid = 10000;
+        // Temporarily holds the schema-to-OID mapping before it is merged into the global OID cache
+        let mut schema_oid_cache = HashMap::new();
 
-        // Add standard PostgreSQL system schemas
-        // pg_catalog schema (OID 11)
-        oids.push(11);
-        nspnames.push("pg_catalog".to_string());
-        nspowners.push(10); // Default superuser
-        nspacls.push(None);
-        options.push(None);
-
-        // public schema (OID 2200)
-        oids.push(2200);
-        nspnames.push("public".to_string());
-        nspowners.push(10); // Default superuser
-        nspacls.push(None);
-        options.push(None);
-
-        // information_schema (OID 12)
-        oids.push(12);
-        nspnames.push("information_schema".to_string());
-        nspowners.push(10); // Default superuser
-        nspacls.push(None);
-        options.push(None);
+        let mut oid_cache = this.oid_cache.write().await;
 
         // Now add all schemas from DataFusion catalogs
-        for catalog_name in catalog_list.catalog_names() {
-            if let Some(catalog) = catalog_list.catalog(&catalog_name) {
+        for catalog_name in this.catalog_list.catalog_names() {
+            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
                 for schema_name in catalog.schema_names() {
-                    // Skip schemas we've already added as system schemas
-                    if schema_name == "pg_catalog"
-                        || schema_name == "public"
-                        || schema_name == "information_schema"
-                    {
-                        continue;
-                    }
-
-                    let schema_oid = next_oid;
-                    next_oid += 1;
-
-                    oids.push(schema_oid);
+                    let cache_key = OidCacheKey::Schema(schema_name.clone());
+                    let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                        *oid
+                    } else {
+                        this.oid_counter.fetch_add(1, Ordering::Relaxed)
+                    };
+                    schema_oid_cache.insert(cache_key, schema_oid);
+
+                    oids.push(schema_oid as i32);
                     nspnames.push(schema_name.clone());
                     nspowners.push(10); // Default owner
                     nspacls.push(None);
@@ -1325,6 +1351,16 @@ impl PgNamespaceTable {
             }
         }
 
+        // Remove all cached schema OIDs, plus cached table OIDs whose schema no longer exists
+        oid_cache.retain(|key, _| match key {
+            OidCacheKey::Schema(_) => false,
+            OidCacheKey::Table(schema_name, _) => {
+                schema_oid_cache.contains_key(&OidCacheKey::Schema(schema_name.clone()))
+            }
+        });
+        // Merge the freshly built schema cache back in
+        oid_cache.extend(schema_oid_cache);
+
         // Create Arrow arrays from the collected data
         let arrays: Vec<ArrayRef> = vec![
             Arc::new(Int32Array::from(oids)),
@@ -1335,7 +1371,7 @@ impl PgNamespaceTable {
         ];
 
         // Create a full record batch
-        let batch = RecordBatch::try_new(schema.clone(), arrays)?;
+        let batch = RecordBatch::try_new(this.schema.clone(), arrays)?;
 
         Ok(batch)
     }
@@ -1347,11 +1383,10 @@ impl PartitionStream for PgNamespaceTable {
     }
 
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
-        let catalog_list = self.catalog_list.clone();
-        let schema = Arc::clone(&self.schema);
+        let this = self.clone();
         Box::pin(RecordBatchStreamAdapter::new(
-            schema.clone(),
-            futures::stream::once(async move { Self::get_data(schema, catalog_list).await }),
+            this.schema.clone(),
+            futures::stream::once(async move { Self::get_data(this).await }),
         ))
     }
}
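
The pattern this diff introduces — a shared monotonic `AtomicU32` for allocating OIDs plus a `tokio::sync::RwLock<HashMap<OidCacheKey, Oid>>` cache that is rebuilt on every catalog scan, so surviving schemas and tables keep their OIDs while dropped ones fall away — can be exercised in isolation. Below is a minimal, self-contained sketch, not code from this PR: `OidAllocator` and `refresh` are hypothetical names, and a plain `u32` stands in for `postgres_types::Oid` (itself a `u32` alias).

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use tokio::sync::RwLock;

// Mirrors the diff's cache key: a schema, or a table qualified by its schema.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
enum OidCacheKey {
    Schema(String),
    Table(String, String),
}

// Shared state as in the diff: a monotonic counter plus a rebuildable cache.
#[derive(Clone)]
struct OidAllocator {
    counter: Arc<AtomicU32>,
    cache: Arc<RwLock<HashMap<OidCacheKey, u32>>>,
}

impl OidAllocator {
    fn new() -> Self {
        Self {
            counter: Arc::new(AtomicU32::new(0)),
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    // Rebuild the cache from the currently visible objects: reuse the OID of
    // every key that is still present, allocate fresh OIDs for new keys, and
    // let entries for dropped objects disappear with the old map.
    async fn refresh(&self, live_keys: Vec<OidCacheKey>) -> HashMap<OidCacheKey, u32> {
        let mut cache = self.cache.write().await;
        let mut swap_cache = HashMap::new();
        for key in live_keys {
            let oid = match cache.get(&key) {
                Some(oid) => *oid,
                None => self.counter.fetch_add(1, Ordering::Relaxed),
            };
            swap_cache.insert(key, oid);
        }
        *cache = swap_cache.clone();
        swap_cache
    }
}

#[tokio::main]
async fn main() {
    let allocator = OidAllocator::new();

    // First scan: both objects get fresh OIDs.
    let first = allocator
        .refresh(vec![
            OidCacheKey::Schema("public".into()),
            OidCacheKey::Table("public".into(), "t1".into()),
        ])
        .await;

    // Second scan: "t1" was dropped; "public" must keep the OID it had.
    let second = allocator
        .refresh(vec![OidCacheKey::Schema("public".into())])
        .await;

    let key = OidCacheKey::Schema("public".into());
    assert_eq!(first.get(&key), second.get(&key));
    println!("schema OID stable across scans: {:?}", second.get(&key));
}
```

Because `PgClassTable` and `PgNamespaceTable` share one counter and one cache, a schema resolves to the same OID whichever table materializes it first — which is what lets clients join `pg_class.relnamespace` against `pg_namespace.oid` across separate queries.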