diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs
index 99268d1..2ed146d 100644
--- a/datafusion-postgres/src/pg_catalog.rs
+++ b/datafusion-postgres/src/pg_catalog.rs
@@ -214,9 +214,10 @@ impl PgTypesData {
 
 #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
 enum OidCacheKey {
-    Schema(String),
+    Catalog(String),
+    Schema(String, String),
     /// Table by schema and table name
-    Table(String, String),
+    Table(String, String, String),
 }
 
 // Create custom schema provider for pg_catalog
@@ -262,7 +263,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {
                 )))
             }
             PG_CATALOG_TABLE_PG_DATABASE => {
-                let table = Arc::new(PgDatabaseTable::new(self.catalog_list.clone()));
+                let table = Arc::new(PgDatabaseTable::new(
+                    self.catalog_list.clone(),
+                    self.oid_counter.clone(),
+                    self.oid_cache.clone(),
+                ));
                 Ok(Some(Arc::new(
                     StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
                 )))
@@ -289,7 +294,7 @@ impl PgCatalogSchemaProvider {
     pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgCatalogSchemaProvider {
         Self {
             catalog_list,
-            oid_counter: Arc::new(AtomicU32::new(0)),
+            oid_counter: Arc::new(AtomicU32::new(16384)),
             oid_cache: Arc::new(RwLock::new(HashMap::new())),
         }
     }
@@ -1156,10 +1161,19 @@ impl PgClassTable {
 
         // Iterate through all catalogs and schemas
        for catalog_name in this.catalog_list.catalog_names() {
+            let cache_key = OidCacheKey::Catalog(catalog_name.clone());
+            let catalog_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                *oid
+            } else {
+                this.oid_counter.fetch_add(1, Ordering::Relaxed)
+            };
+            swap_cache.insert(cache_key, catalog_oid);
+
             if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
                 for schema_name in catalog.schema_names() {
                     if let Some(schema) = catalog.schema(&schema_name) {
-                        let cache_key = OidCacheKey::Schema(schema_name.clone());
+                        let cache_key =
+                            OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
                         let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
                             *oid
                         } else {
@@ -1172,8 +1186,11 @@
 
                         // Now process all tables in this schema
                         for table_name in schema.table_names() {
-                            let cache_key =
-                                OidCacheKey::Table(schema_name.clone(), table_name.clone());
+                            let cache_key = OidCacheKey::Table(
+                                catalog_name.clone(),
+                                schema_name.clone(),
+                                table_name.clone(),
+                            );
                             let table_oid = if let Some(oid) = oid_cache.get(&cache_key) {
                                 *oid
                             } else {
@@ -1334,7 +1351,7 @@ impl PgNamespaceTable {
         for catalog_name in this.catalog_list.catalog_names() {
             if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
                 for schema_name in catalog.schema_names() {
-                    let cache_key = OidCacheKey::Schema(schema_name.clone());
+                    let cache_key = OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
                     let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
                         *oid
                     } else {
@@ -1353,10 +1370,10 @@
 
         // remove all schema cache and table of the schema which is no longer exists
         oid_cache.retain(|key, _| match key {
-            OidCacheKey::Schema(_) => false,
-            OidCacheKey::Table(schema_name, _) => {
-                schema_oid_cache.contains_key(&OidCacheKey::Schema(schema_name.clone()))
-            }
+            OidCacheKey::Catalog(..) => true,
+            OidCacheKey::Schema(..) => false,
+            OidCacheKey::Table(catalog, schema_name, _) => schema_oid_cache
+                .contains_key(&OidCacheKey::Schema(catalog.clone(), schema_name.clone())),
         });
         // add new schema cache
         oid_cache.extend(schema_oid_cache);
@@ -1391,14 +1408,20 @@ impl PartitionStream for PgNamespaceTable {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct PgDatabaseTable {
     schema: SchemaRef,
     catalog_list: Arc<dyn CatalogProviderList>,
+    oid_counter: Arc<AtomicU32>,
+    oid_cache: Arc<RwLock<HashMap<OidCacheKey, u32>>>,
 }
 
 impl PgDatabaseTable {
-    pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> Self {
+    pub fn new(
+        catalog_list: Arc<dyn CatalogProviderList>,
+        oid_counter: Arc<AtomicU32>,
+        oid_cache: Arc<RwLock<HashMap<OidCacheKey, u32>>>,
+    ) -> Self {
         // Define the schema for pg_database
         // This matches PostgreSQL's pg_database table columns
         let schema = Arc::new(Schema::new(vec![
@@ -1421,14 +1444,13 @@ impl PgDatabaseTable {
         Self {
             schema,
             catalog_list,
+            oid_counter,
+            oid_cache,
         }
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(
-        schema: SchemaRef,
-        catalog_list: Arc<dyn CatalogProviderList>,
-    ) -> Result<RecordBatch> {
+    async fn get_data(this: PgDatabaseTable) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut oids = Vec::new();
         let mut datnames = Vec::new();
@@ -1445,15 +1467,22 @@
         let mut dattablespaces = Vec::new();
         let mut datacles: Vec<Option<String>> = Vec::new();
 
-        // Start OID counter (this is simplistic and would need to be more robust in practice)
-        let mut next_oid = 16384; // Standard PostgreSQL starting OID for user databases
+        // temporarily holds the catalog-OID mapping before it is merged into the global cache
+        let mut catalog_oid_cache = HashMap::new();
 
-        // Add a record for each catalog (treating catalogs as "databases")
-        for catalog_name in catalog_list.catalog_names() {
-            let oid = next_oid;
-            next_oid += 1;
+        let mut oid_cache = this.oid_cache.write().await;
 
-            oids.push(oid);
+        // Add a record for each catalog (treating catalogs as "databases")
+        for catalog_name in this.catalog_list.catalog_names() {
+            let cache_key = OidCacheKey::Catalog(catalog_name.clone());
+            let catalog_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                *oid
+            } else {
+                this.oid_counter.fetch_add(1, Ordering::Relaxed)
+            };
+            catalog_oid_cache.insert(cache_key, catalog_oid);
+
+            oids.push(catalog_oid as i32);
             datnames.push(catalog_name.clone());
             datdbas.push(10); // Default owner (assuming 10 = postgres user)
             encodings.push(6); // 6 = UTF8 in PostgreSQL
@@ -1471,11 +1500,18 @@
 
         // Always include a "postgres" database entry if not already present
         // (This is for compatibility with tools that expect it)
-        if !datnames.contains(&"postgres".to_string()) {
-            let oid = next_oid;
-
-            oids.push(oid);
-            datnames.push("postgres".to_string());
+        let default_datname = "postgres".to_string();
+        if !datnames.contains(&default_datname) {
+            let cache_key = OidCacheKey::Catalog(default_datname.clone());
+            let catalog_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                *oid
+            } else {
+                this.oid_counter.fetch_add(1, Ordering::Relaxed)
+            };
+            catalog_oid_cache.insert(cache_key, catalog_oid);
+
+            oids.push(catalog_oid as i32);
+            datnames.push(default_datname);
             datdbas.push(10);
             encodings.push(6);
             datcollates.push("en_US.UTF-8".to_string());
@@ -1509,7 +1545,22 @@
         ];
 
         // Create a full record batch
-        let full_batch = RecordBatch::try_new(schema.clone(), arrays)?;
+        let full_batch = RecordBatch::try_new(this.schema.clone(), arrays)?;
+
+        // update cache
+        // remove cache entries for catalogs that no longer exist
+        oid_cache.retain(|key, _| match key {
+            OidCacheKey::Catalog(..) => false,
+            OidCacheKey::Schema(catalog, ..) => {
+                catalog_oid_cache.contains_key(&OidCacheKey::Catalog(catalog.clone()))
+            }
+            OidCacheKey::Table(catalog, ..) => {
+                catalog_oid_cache.contains_key(&OidCacheKey::Catalog(catalog.clone()))
+            }
+        });
+        // add the newly scanned catalog entries
+        oid_cache.extend(catalog_oid_cache);
+
         Ok(full_batch)
     }
 }
@@ -1520,11 +1571,10 @@ impl PartitionStream for PgDatabaseTable {
     }
 
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
-        let catalog_list = self.catalog_list.clone();
-        let schema = Arc::clone(&self.schema);
+        let this = self.clone();
         Box::pin(RecordBatchStreamAdapter::new(
-            schema.clone(),
-            futures::stream::once(async move { Self::get_data(schema, catalog_list).await }),
+            this.schema.clone(),
+            futures::stream::once(async move { Self::get_data(this).await }),
         ))
     }
 }
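
Taken together, the patch moves `pg_database` off its throwaway per-scan OID numbering and onto the shared, catalog-qualified OID cache that `pg_class` and `pg_namespace` already use: OIDs come from one process-wide `AtomicU32` seeded at 16384 (the first OID PostgreSQL assigns to user objects), and each scan reconciles the cache with a retain/extend pass that evicts entries whose catalog has disappeared. The sketch below is a minimal, self-contained illustration of that pattern, not the crate's actual code: `get_or_allocate` and `reconcile` are hypothetical helper names, and a plain `HashMap`/`AtomicU32` stand in for the tokio-`RwLock`-guarded cache the tables share.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};

// Catalog-qualified cache keys, mirroring the enum in the patch: two
// catalogs that contain a schema with the same name get distinct keys.
#[derive(Debug, Hash, Eq, PartialEq)]
enum OidCacheKey {
    Catalog(String),
    Schema(String, String),        // (catalog, schema)
    Table(String, String, String), // (catalog, schema, table)
}

/// Return the cached OID for `key`, or allocate a fresh one.
/// `Ordering::Relaxed` is enough: the counter only has to hand out
/// unique values, it does not synchronize any other memory.
fn get_or_allocate(
    cache: &HashMap<OidCacheKey, u32>,
    counter: &AtomicU32,
    key: &OidCacheKey,
) -> u32 {
    cache
        .get(key)
        .copied()
        .unwrap_or_else(|| counter.fetch_add(1, Ordering::Relaxed))
}

/// The reconciliation step: evict entries whose catalog vanished, then
/// merge the freshly scanned catalog map back into the shared cache.
fn reconcile(oid_cache: &mut HashMap<OidCacheKey, u32>, fresh: HashMap<OidCacheKey, u32>) {
    oid_cache.retain(|key, _| match key {
        // Catalog entries are replaced wholesale by `fresh` below.
        OidCacheKey::Catalog(..) => false,
        // Schema and table entries survive only while their catalog does.
        OidCacheKey::Schema(catalog, ..) | OidCacheKey::Table(catalog, ..) => {
            fresh.contains_key(&OidCacheKey::Catalog(catalog.clone()))
        }
    });
    oid_cache.extend(fresh);
}

fn main() {
    // 16384 is the first OID PostgreSQL gives to user objects; seeding
    // the counter there keeps generated OIDs out of the reserved range.
    let counter = AtomicU32::new(16384);
    let mut oid_cache: HashMap<OidCacheKey, u32> = HashMap::new();

    // First scan sees two catalogs.
    let mut fresh = HashMap::new();
    for name in ["db1", "db2"] {
        let key = OidCacheKey::Catalog(name.to_string());
        let oid = get_or_allocate(&oid_cache, &counter, &key);
        fresh.insert(key, oid);
    }
    reconcile(&mut oid_cache, fresh);
    let db1 = oid_cache[&OidCacheKey::Catalog("db1".into())];

    // Second scan: "db2" was dropped. "db1" must keep its OID so that
    // clients which already resolved pg_database.oid stay consistent.
    let mut fresh = HashMap::new();
    let key = OidCacheKey::Catalog("db1".to_string());
    let oid = get_or_allocate(&oid_cache, &counter, &key);
    fresh.insert(key, oid);
    reconcile(&mut oid_cache, fresh);

    assert_eq!(oid_cache[&OidCacheKey::Catalog("db1".into())], db1);
    assert!(!oid_cache.contains_key(&OidCacheKey::Catalog("db2".into())));
    println!("db1 keeps oid {db1} across rescans");
}
```

The `Clone` derive on `PgDatabaseTable` serves the same goal on the streaming side: `execute` clones the whole table value once so the `async move` closure can carry the schema, catalog list, OID counter, and cache into `get_data` together.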