diff --git a/datafusion-postgres-cli/src/main.rs b/datafusion-postgres-cli/src/main.rs index f340c1f..9bfdc6e 100644 --- a/datafusion-postgres-cli/src/main.rs +++ b/datafusion-postgres-cli/src/main.rs @@ -6,7 +6,7 @@ use datafusion::execution::options::{ ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_postgres::pg_catalog::PgCatalogSchemaProvider; +use datafusion_postgres::pg_catalog::setup_pg_catalog; use datafusion_postgres::{serve, ServerOptions}; // Assuming the crate name is `datafusion_postgres` use structopt::StructOpt; @@ -171,11 +171,7 @@ async fn setup_session_context( } // Register pg_catalog - let pg_catalog = PgCatalogSchemaProvider::new(session_context.state().catalog_list().clone()); - session_context - .catalog("datafusion") - .unwrap() - .register_schema("pg_catalog", Arc::new(pg_catalog))?; + setup_pg_catalog(session_context, "datafusion")?; Ok(()) } diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index c6af863..28a1469 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -2,16 +2,20 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::array::{ - ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch, StringArray, + as_boolean_array, ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch, + StringArray, StringBuilder, }; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider}; +use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::datasource::TableProvider; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; +use datafusion::prelude::{create_udf, SessionContext}; const PG_CATALOG_TABLE_PG_TYPE: &str = "pg_type"; const PG_CATALOG_TABLE_PG_CLASS: &str = "pg_class"; @@ -595,3 +599,77 @@ impl PartitionStream for PgDatabaseTable { )) } } + +pub fn create_current_schemas_udf() -> ScalarUDF { + // Define the function implementation + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let input = as_boolean_array(&args[0]); + + // Create a UTF8 array with a single value + let mut values = vec!["public"]; + // include implicit schemas + if input.value(0) { + values.push("information_schema"); + values.push("pg_catalog"); + } + + let list_array = SingleRowListArrayBuilder::new(Arc::new(StringArray::from(values))); + + let array: ArrayRef = Arc::new(list_array.build_list_array()); + + Ok(ColumnarValue::Array(array)) + }; + + // Wrap the implementation in a scalar function + create_udf( + "current_schemas", + vec![DataType::Boolean], + DataType::List(Arc::new(Field::new("schema", DataType::Utf8, false))), + Volatility::Immutable, + Arc::new(func), + ) +} + +pub fn create_current_schema_udf() -> ScalarUDF { + // Define the function implementation + let func = move |_args: &[ColumnarValue]| { + // Create a UTF8 array with a single value + let mut builder = StringBuilder::new(); + builder.append_value("public"); + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + // Wrap the implementation in a scalar function + create_udf( + "current_schema", + vec![], + DataType::Utf8, + Volatility::Immutable, + Arc::new(func), + ) +} + +/// Install pg_catalog and postgres UDFs to current `SessionContext` +pub fn setup_pg_catalog( + session_context: &SessionContext, + catalog_name: &str, +) -> Result<(), Box> { + let pg_catalog = PgCatalogSchemaProvider::new(session_context.state().catalog_list().clone()); + session_context + .catalog(catalog_name) + .ok_or_else(|| { + DataFusionError::Configuration(format!( + "Catalog not found when registering pg_catalog: {}", + catalog_name + )) + })? + .register_schema("pg_catalog", Arc::new(pg_catalog))?; + + session_context.register_udf(create_current_schema_udf()); + session_context.register_udf(create_current_schemas_udf()); + + Ok(()) +}