Skip to content

Commit 9896a73

Browse files
committed
feat: add udf current_schema and current_schemas
1 parent f99e0eb commit 9896a73

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

datafusion-postgres-cli/src/main.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use datafusion::execution::options::{
66
ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
77
};
88
use datafusion::prelude::{SessionConfig, SessionContext};
9-
use datafusion_postgres::pg_catalog::PgCatalogSchemaProvider;
9+
use datafusion_postgres::pg_catalog::setup_pg_catalog;
1010
use datafusion_postgres::{serve, ServerOptions}; // Assuming the crate name is `datafusion_postgres`
1111
use structopt::StructOpt;
1212

@@ -171,11 +171,7 @@ async fn setup_session_context(
171171
}
172172

173173
// Register pg_catalog
174-
let pg_catalog = PgCatalogSchemaProvider::new(session_context.state().catalog_list().clone());
175-
session_context
176-
.catalog("datafusion")
177-
.unwrap()
178-
.register_schema("pg_catalog", Arc::new(pg_catalog))?;
174+
setup_pg_catalog(session_context)?;
179175

180176
Ok(())
181177
}

datafusion-postgres/src/pg_catalog.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@ use std::sync::Arc;
22

33
use async_trait::async_trait;
44
use datafusion::arrow::array::{
5-
ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch, StringArray,
5+
as_boolean_array, ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch,
6+
StringArray, StringBuilder,
67
};
78
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
89
use datafusion::catalog::streaming::StreamingTable;
910
use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider};
11+
use datafusion::common::utils::SingleRowListArrayBuilder;
1012
use datafusion::datasource::TableProvider;
1113
use datafusion::error::Result;
1214
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
15+
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
1316
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
1417
use datafusion::physical_plan::streaming::PartitionStream;
18+
use datafusion::prelude::{create_udf, SessionContext};
1519

1620
const PG_CATALOG_TABLE_PG_TYPE: &str = "pg_type";
1721
const PG_CATALOG_TABLE_PG_CLASS: &str = "pg_class";
@@ -595,3 +599,69 @@ impl PartitionStream for PgDatabaseTable {
595599
))
596600
}
597601
}
602+
603+
pub fn create_current_schemas_udf() -> ScalarUDF {
604+
// Define the function implementation
605+
let func = move |args: &[ColumnarValue]| {
606+
let args = ColumnarValue::values_to_arrays(args)?;
607+
let input = as_boolean_array(&args[0]);
608+
609+
// Create a UTF8 array with a single value
610+
let mut values = vec!["public"];
611+
// include implicit schemas
612+
if input.value(0) {
613+
values.push("information_schema");
614+
values.push("pg_catalog");
615+
}
616+
617+
let list_array = SingleRowListArrayBuilder::new(Arc::new(StringArray::from(values)));
618+
619+
let array: ArrayRef = Arc::new(list_array.build_list_array());
620+
621+
Ok(ColumnarValue::Array(array))
622+
};
623+
624+
// Wrap the implementation in a scalar function
625+
create_udf(
626+
"current_schemas",
627+
vec![DataType::Boolean],
628+
DataType::List(Arc::new(Field::new("schema", DataType::Utf8, false))),
629+
Volatility::Immutable,
630+
Arc::new(func),
631+
)
632+
}
633+
634+
pub fn create_current_schema_udf() -> ScalarUDF {
635+
// Define the function implementation
636+
let func = move |_args: &[ColumnarValue]| {
637+
// Create a UTF8 array with a single value
638+
let mut builder = StringBuilder::new();
639+
builder.append_value("public");
640+
let array: ArrayRef = Arc::new(builder.finish());
641+
642+
Ok(ColumnarValue::Array(array))
643+
};
644+
645+
// Wrap the implementation in a scalar function
646+
create_udf(
647+
"current_schema",
648+
vec![],
649+
DataType::Utf8,
650+
Volatility::Immutable,
651+
Arc::new(func),
652+
)
653+
}
654+
655+
/// Install pg_catalog and postgres UDFs to current `SessionContext`
656+
pub fn setup_pg_catalog(session_context: &SessionContext) -> Result<()> {
657+
let pg_catalog = PgCatalogSchemaProvider::new(session_context.state().catalog_list().clone());
658+
session_context
659+
.catalog("datafusion")
660+
.unwrap()
661+
.register_schema("pg_catalog", Arc::new(pg_catalog))?;
662+
663+
session_context.register_udf(create_current_schema_udf());
664+
session_context.register_udf(create_current_schemas_udf());
665+
666+
Ok(())
667+
}

0 commit comments

Comments
 (0)