Skip to content

Commit a230f06

Browse files
authored
refactor: add a layer of abstraction for CatalogProviderList (#161)
1 parent 64de74b commit a230f06

File tree

7 files changed

+191
-108
lines changed

7 files changed

+191
-108
lines changed

datafusion-postgres/src/pg_catalog.rs

Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@ use datafusion::arrow::array::{
99
use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
1010
use datafusion::arrow::ipc::reader::FileReader;
1111
use datafusion::catalog::streaming::StreamingTable;
12-
use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl};
12+
use datafusion::catalog::{MemTable, SchemaProvider, TableFunctionImpl};
1313
use datafusion::common::utils::SingleRowListArrayBuilder;
14-
use datafusion::datasource::{TableProvider, ViewTable};
14+
use datafusion::datasource::TableProvider;
1515
use datafusion::error::{DataFusionError, Result};
1616
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
1717
use datafusion::physical_plan::streaming::PartitionStream;
1818
use datafusion::prelude::{create_udf, Expr, SessionContext};
1919
use postgres_types::Oid;
2020
use tokio::sync::RwLock;
2121

22+
use crate::pg_catalog::catalog_info::CatalogInfo;
23+
24+
pub mod catalog_info;
2225
pub mod empty_table;
2326
pub mod has_privilege_udf;
2427
pub mod pg_attribute;
@@ -97,37 +100,6 @@ const PG_CATALOG_VIEW_PG_MATVIEWS: &str = "pg_matviews";
97100
const PG_CATALOG_VIEW_PG_TABLES: &str = "pg_tables";
98101
const PG_CATALOG_VIEW_PG_STAT_USER_TABELS: &str = "pg_stat_user_tables";
99102

100-
/// Determine PostgreSQL table type (relkind) from DataFusion TableProvider
101-
fn get_table_type(table: &Arc<dyn TableProvider>) -> &'static str {
102-
// Use Any trait to determine the actual table provider type
103-
if table.as_any().is::<ViewTable>() {
104-
"v" // view
105-
} else {
106-
"r" // All other table types (StreamingTable, MemTable, etc.) are treated as regular tables
107-
}
108-
}
109-
110-
/// Determine PostgreSQL table type (relkind) with table name context
111-
fn get_table_type_with_name(
112-
table: &Arc<dyn TableProvider>,
113-
table_name: &str,
114-
schema_name: &str,
115-
) -> &'static str {
116-
// Check if this is a system catalog table
117-
if schema_name == "pg_catalog" || schema_name == "information_schema" {
118-
if table_name.starts_with("pg_")
119-
|| table_name.contains("_table")
120-
|| table_name.contains("_column")
121-
{
122-
"r" // System tables are still regular tables in PostgreSQL
123-
} else {
124-
"v" // Some system objects might be views
125-
}
126-
} else {
127-
get_table_type(table)
128-
}
129-
}
130-
131103
pub const PG_CATALOG_TABLES: &[&str] = &[
132104
PG_CATALOG_TABLE_PG_AGGREGATE,
133105
PG_CATALOG_TABLE_PG_AM,
@@ -206,15 +178,15 @@ pub(crate) enum OidCacheKey {
206178

207179
// Create custom schema provider for pg_catalog
208180
#[derive(Debug)]
209-
pub struct PgCatalogSchemaProvider {
210-
catalog_list: Arc<dyn CatalogProviderList>,
181+
pub struct PgCatalogSchemaProvider<C> {
182+
catalog_list: C,
211183
oid_counter: Arc<AtomicU32>,
212184
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
213185
static_tables: Arc<PgCatalogStaticTables>,
214186
}
215187

216188
#[async_trait]
217-
impl SchemaProvider for PgCatalogSchemaProvider {
189+
impl<C: CatalogInfo> SchemaProvider for PgCatalogSchemaProvider<C> {
218190
fn as_any(&self) -> &dyn std::any::Any {
219191
self
220192
}
@@ -389,11 +361,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {
389361
}
390362
}
391363

392-
impl PgCatalogSchemaProvider {
364+
impl<C: CatalogInfo> PgCatalogSchemaProvider<C> {
393365
pub fn try_new(
394-
catalog_list: Arc<dyn CatalogProviderList>,
366+
catalog_list: C,
395367
static_tables: Arc<PgCatalogStaticTables>,
396-
) -> Result<PgCatalogSchemaProvider> {
368+
) -> Result<PgCatalogSchemaProvider<C>> {
397369
Ok(Self {
398370
catalog_list,
399371
oid_counter: Arc::new(AtomicU32::new(16384)),
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
use std::fmt::Debug;
2+
use std::sync::Arc;
3+
4+
use async_trait::async_trait;
5+
use datafusion::{
6+
arrow::datatypes::SchemaRef, catalog::CatalogProviderList, datasource::TableType,
7+
error::DataFusionError,
8+
};
9+
10+
/// Defines the interface for retrieving catalog metadata used to populate
/// `pg_catalog` tables.
///
/// Abstracting over `CatalogProviderList` allows the `pg_catalog` table
/// providers to query any backend that can answer these lookups.
#[async_trait]
pub trait CatalogInfo: Clone + Send + Sync + Debug + 'static {
    /// Names of all catalogs known to this provider.
    fn catalog_names(&self) -> Vec<String>;

    /// Names of all schemas in `catalog_name`, or `None` if the catalog
    /// does not exist.
    fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>>;

    /// Names of all tables in `catalog_name.schema_name`, or `None` if the
    /// catalog or schema does not exist.
    fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>>;

    /// Arrow schema of the named table, or `Ok(None)` if the catalog,
    /// schema, or table cannot be resolved.
    async fn table_schema(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
    ) -> Result<Option<SchemaRef>, DataFusionError>;

    /// Table type (base table, view, temporary) of the named table, or
    /// `Ok(None)` if it cannot be resolved.
    async fn table_type(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
    ) -> Result<Option<TableType>, DataFusionError>;
}
33+
34+
#[async_trait]
35+
impl CatalogInfo for Arc<dyn CatalogProviderList> {
36+
fn catalog_names(&self) -> Vec<String> {
37+
CatalogProviderList::catalog_names(self.as_ref())
38+
}
39+
40+
fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>> {
41+
self.catalog(catalog_name).map(|c| c.schema_names())
42+
}
43+
44+
fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>> {
45+
self.catalog(catalog_name)
46+
.and_then(|c| c.schema(schema_name))
47+
.map(|s| s.table_names())
48+
}
49+
50+
async fn table_schema(
51+
&self,
52+
catalog_name: &str,
53+
schema_name: &str,
54+
table_name: &str,
55+
) -> Result<Option<SchemaRef>, DataFusionError> {
56+
let schema = self
57+
.catalog(catalog_name)
58+
.and_then(|c| c.schema(schema_name));
59+
if let Some(schema) = schema {
60+
let table_schema = schema.table(table_name).await?.map(|t| t.schema());
61+
Ok(table_schema)
62+
} else {
63+
Ok(None)
64+
}
65+
}
66+
67+
async fn table_type(
68+
&self,
69+
catalog_name: &str,
70+
schema_name: &str,
71+
table_name: &str,
72+
) -> Result<Option<TableType>, DataFusionError> {
73+
let schema = self
74+
.catalog(catalog_name)
75+
.and_then(|c| c.schema(schema_name));
76+
if let Some(schema) = schema {
77+
let table_type = schema.table_type(table_name).await?;
78+
Ok(table_type)
79+
} else {
80+
Ok(None)
81+
}
82+
}
83+
}
84+
85+
pub fn table_type_to_string(tt: &TableType) -> String {
86+
match tt {
87+
TableType::Base => "r".to_string(),
88+
TableType::View => "v".to_string(),
89+
TableType::Temporary => "r".to_string(),
90+
}
91+
}

datafusion-postgres/src/pg_catalog/pg_attribute.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,28 @@ use datafusion::arrow::array::{
66
ArrayRef, BooleanArray, Int16Array, Int32Array, RecordBatch, StringArray,
77
};
88
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
9-
use datafusion::catalog::CatalogProviderList;
109
use datafusion::error::Result;
1110
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1211
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
1312
use datafusion::physical_plan::streaming::PartitionStream;
1413
use postgres_types::Oid;
1514
use tokio::sync::RwLock;
1615

16+
use crate::pg_catalog::catalog_info::CatalogInfo;
17+
1718
use super::OidCacheKey;
1819

1920
#[derive(Debug, Clone)]
20-
pub(crate) struct PgAttributeTable {
21+
pub(crate) struct PgAttributeTable<C> {
2122
schema: SchemaRef,
22-
catalog_list: Arc<dyn CatalogProviderList>,
23+
catalog_list: C,
2324
oid_counter: Arc<AtomicU32>,
2425
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
2526
}
2627

27-
impl PgAttributeTable {
28+
impl<C: CatalogInfo> PgAttributeTable<C> {
2829
pub(crate) fn new(
29-
catalog_list: Arc<dyn CatalogProviderList>,
30+
catalog_list: C,
3031
oid_counter: Arc<AtomicU32>,
3132
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
3233
) -> Self {
@@ -105,11 +106,13 @@ impl PgAttributeTable {
105106
let mut swap_cache = HashMap::new();
106107

107108
for catalog_name in this.catalog_list.catalog_names() {
108-
if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
109-
for schema_name in catalog.schema_names() {
110-
if let Some(schema_provider) = catalog.schema(&schema_name) {
109+
if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
110+
for schema_name in schema_names {
111+
if let Some(table_names) =
112+
this.catalog_list.table_names(&catalog_name, &schema_name)
113+
{
111114
// Process all tables in this schema
112-
for table_name in schema_provider.table_names() {
115+
for table_name in table_names {
113116
let cache_key = OidCacheKey::Table(
114117
catalog_name.clone(),
115118
schema_name.clone(),
@@ -122,9 +125,11 @@ impl PgAttributeTable {
122125
};
123126
swap_cache.insert(cache_key, table_oid);
124127

125-
if let Some(table) = schema_provider.table(&table_name).await? {
126-
let table_schema = table.schema();
127-
128+
if let Some(table_schema) = this
129+
.catalog_list
130+
.table_schema(&catalog_name, &schema_name, &table_name)
131+
.await?
132+
{
128133
// Add column entries for this table
129134
for (column_idx, field) in table_schema.fields().iter().enumerate()
130135
{
@@ -233,7 +238,7 @@ impl PgAttributeTable {
233238
}
234239
}
235240

236-
impl PartitionStream for PgAttributeTable {
241+
impl<C: CatalogInfo> PartitionStream for PgAttributeTable<C> {
237242
fn schema(&self) -> &SchemaRef {
238243
&self.schema
239244
}

datafusion-postgres/src/pg_catalog/pg_class.rs

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,32 @@ use datafusion::arrow::array::{
66
ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch, StringArray,
77
};
88
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
9-
use datafusion::catalog::CatalogProviderList;
9+
use datafusion::datasource::TableType;
1010
use datafusion::error::Result;
1111
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1212
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
1313
use datafusion::physical_plan::streaming::PartitionStream;
1414
use postgres_types::Oid;
1515
use tokio::sync::RwLock;
1616

17-
use super::{get_table_type_with_name, OidCacheKey};
17+
use crate::pg_catalog::catalog_info::{table_type_to_string, CatalogInfo};
18+
19+
use super::OidCacheKey;
1820

1921
#[derive(Debug, Clone)]
20-
pub(crate) struct PgClassTable {
22+
pub(crate) struct PgClassTable<C> {
2123
schema: SchemaRef,
22-
catalog_list: Arc<dyn CatalogProviderList>,
24+
catalog_list: C,
2325
oid_counter: Arc<AtomicU32>,
2426
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
2527
}
2628

27-
impl PgClassTable {
29+
impl<C: CatalogInfo> PgClassTable<C> {
2830
pub(crate) fn new(
29-
catalog_list: Arc<dyn CatalogProviderList>,
31+
catalog_list: C,
3032
oid_counter: Arc<AtomicU32>,
3133
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
32-
) -> PgClassTable {
34+
) -> Self {
3335
// Define the schema for pg_class
3436
// This matches key columns from PostgreSQL's pg_class
3537
let schema = Arc::new(Schema::new(vec![
@@ -75,7 +77,7 @@ impl PgClassTable {
7577
}
7678

7779
/// Generate record batches based on the current state of the catalog
78-
async fn get_data(this: PgClassTable) -> Result<RecordBatch> {
80+
async fn get_data(this: Self) -> Result<RecordBatch> {
7981
// Vectors to store column data
8082
let mut oids = Vec::new();
8183
let mut relnames = Vec::new();
@@ -124,23 +126,24 @@ impl PgClassTable {
124126
};
125127
swap_cache.insert(cache_key, catalog_oid);
126128

127-
if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
128-
for schema_name in catalog.schema_names() {
129-
if let Some(schema) = catalog.schema(&schema_name) {
130-
let cache_key =
131-
OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
132-
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
133-
*oid
134-
} else {
135-
this.oid_counter.fetch_add(1, Ordering::Relaxed)
136-
};
137-
swap_cache.insert(cache_key, schema_oid);
129+
if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
130+
for schema_name in schema_names {
131+
let cache_key = OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
132+
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
133+
*oid
134+
} else {
135+
this.oid_counter.fetch_add(1, Ordering::Relaxed)
136+
};
137+
swap_cache.insert(cache_key, schema_oid);
138138

139-
// Add an entry for the schema itself (as a namespace)
140-
// (In a full implementation, this would go in pg_namespace)
139+
// Add an entry for the schema itself (as a namespace)
140+
// (In a full implementation, this would go in pg_namespace)
141141

142-
// Now process all tables in this schema
143-
for table_name in schema.table_names() {
142+
// Now process all tables in this schema
143+
if let Some(table_names) =
144+
this.catalog_list.table_names(&catalog_name, &schema_name)
145+
{
146+
for table_name in table_names {
144147
let cache_key = OidCacheKey::Table(
145148
catalog_name.clone(),
146149
schema_name.clone(),
@@ -153,13 +156,20 @@ impl PgClassTable {
153156
};
154157
swap_cache.insert(cache_key, table_oid);
155158

156-
if let Some(table) = schema.table(&table_name).await? {
159+
if let Some(table_schema) = this
160+
.catalog_list
161+
.table_schema(&catalog_name, &schema_name, &table_name)
162+
.await?
163+
{
157164
// Determine the correct table type based on the table provider and context
158-
let table_type =
159-
get_table_type_with_name(&table, &table_name, &schema_name);
165+
let table_type = this
166+
.catalog_list
167+
.table_type(&catalog_name, &schema_name, &table_name)
168+
.await?
169+
.unwrap_or(TableType::Temporary);
160170

161171
// Get column count from schema
162-
let column_count = table.schema().fields().len() as i16;
172+
let column_count = table_schema.fields().len() as i16;
163173

164174
// Add table entry
165175
oids.push(table_oid as i32);
@@ -178,7 +188,7 @@ impl PgClassTable {
178188
relhasindexes.push(false);
179189
relisshareds.push(false);
180190
relpersistences.push("p".to_string()); // Permanent
181-
relkinds.push(table_type.to_string());
191+
relkinds.push(table_type_to_string(&table_type));
182192
relnattses.push(column_count);
183193
relcheckses.push(0);
184194
relhasruleses.push(false);
@@ -244,7 +254,7 @@ impl PgClassTable {
244254
}
245255
}
246256

247-
impl PartitionStream for PgClassTable {
257+
impl<C: CatalogInfo> PartitionStream for PgClassTable<C> {
248258
fn schema(&self) -> &SchemaRef {
249259
&self.schema
250260
}

0 commit comments

Comments
 (0)