diff --git a/Cargo.lock b/Cargo.lock index 064e5f4..2cc6e99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -928,7 +928,7 @@ dependencies = [ "parquet", "rand", "regex", - "sqlparser", + "sqlparser 0.54.0", "tempfile", "tokio", "url", @@ -999,7 +999,7 @@ dependencies = [ "parquet", "paste", "recursive", - "sqlparser", + "sqlparser 0.54.0", "tokio", "web-time", ] @@ -1091,7 +1091,7 @@ dependencies = [ "paste", "recursive", "serde_json", - "sqlparser", + "sqlparser 0.54.0", ] [[package]] @@ -1358,6 +1358,8 @@ dependencies = [ "datafusion", "futures", "pgwire", + "sqlparser 0.55.0", + "tokio", ] [[package]] @@ -1385,7 +1387,7 @@ dependencies = [ "log", "recursive", "regex", - "sqlparser", + "sqlparser 0.54.0", ] [[package]] @@ -3088,6 +3090,16 @@ dependencies = [ "sqlparser_derive", ] +[[package]] +name = "sqlparser" +version = "0.55.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4521174166bac1ff04fe16ef4524c70144cd29682a45978978ca3d7f4e0be11" +dependencies = [ + "log", + "recursive", +] + [[package]] name = "sqlparser_derive" version = "0.3.0" diff --git a/datafusion-postgres-cli/src/main.rs b/datafusion-postgres-cli/src/main.rs index a42317c..5173a25 100644 --- a/datafusion-postgres-cli/src/main.rs +++ b/datafusion-postgres-cli/src/main.rs @@ -49,12 +49,14 @@ async fn main() { let opts = Opt::from_args(); let session_context = SessionContext::new(); + let mut registered_tables = Vec::new(); // Collect table names here for (table_name, table_path) in opts.csv_tables.iter().map(|s| parse_table_def(s.as_ref())) { session_context .register_csv(table_name, table_path, CsvReadOptions::default()) .await .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + registered_tables.push(table_name.to_string()); println!("Loaded {} as table {}", table_path, table_name); } @@ -63,6 +65,7 @@ async fn main() { .register_json(table_name, table_path, NdJsonReadOptions::default()) .await .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + registered_tables.push(table_name.to_string()); println!("Loaded {} as table {}", table_path, table_name); } @@ -75,6 +78,7 @@ async fn main() { .register_arrow(table_name, table_path, ArrowReadOptions::default()) .await .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + registered_tables.push(table_name.to_string()); println!("Loaded {} as table {}", table_path, table_name); } @@ -87,6 +91,7 @@ async fn main() { .register_parquet(table_name, table_path, ParquetReadOptions::default()) .await .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + registered_tables.push(table_name.to_string()); println!("Loaded {} as table {}", table_path, table_name); } @@ -95,12 +100,16 @@ async fn main() { .register_avro(table_name, table_path, AvroReadOptions::default()) .await .unwrap_or_else(|e| panic!("Failed to register table: {table_name}, {e}")); + registered_tables.push(table_name.to_string()); println!("Loaded {} as table {}", table_path, table_name); } - let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new( - session_context, - )))); + let service = DfSessionService::new(session_context, registered_tables); // Pass registered_tables + service + .register_udfs() + .await + .unwrap_or_else(|e| panic!("Failed to register UDFs: {e}")); + let factory = Arc::new(HandlerFactory(Arc::new(service))); let server_addr = format!("{}:{}", opts.host, opts.port); let listener = TcpListener::bind(&server_addr).await.unwrap(); diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 1a4ae7d..4e3134d 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -19,5 +19,7 @@ readme = "../README.md" pgwire = { workspace = true } datafusion = { workspace = true } futures = "0.3" +sqlparser = "0.55" async-trait = "0.1" chrono = { version = "0.4", features = ["std"] } +tokio = { version = "1.0", features = ["sync"] } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 6067e37..7912c69 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -1,19 +1,28 @@ +// handlers.rs + use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; -use datafusion::arrow::datatypes::DataType; -use datafusion::logical_expr::LogicalPlan; +use datafusion::arrow::array::StringArray; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::logical_expr::{create_udf, ColumnarValue, LogicalPlan, Volatility}; use datafusion::prelude::*; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; -use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response}; -use pgwire::api::stmt::QueryParser; -use pgwire::api::stmt::StoredStatement; +use pgwire::api::results::{ + DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, +}; +use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; +use sqlparser::ast::{Expr, Ident, ObjectName, ObjectNamePart, Statement}; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::Parser as SqlParser; +use tokio::sync::RwLock; use crate::datatypes::{self, into_pg_type}; @@ -50,46 +59,255 @@ impl PgWireServerHandlers for HandlerFactory { } pub struct DfSessionService { - session_context: Arc, - parser: Arc, + pub session_context: Arc>, + pub parser: Arc, + custom_session_vars: Arc>>, + registered_tables: Arc>, } impl DfSessionService { - pub fn new(session_context: SessionContext) -> DfSessionService { - let session_context = Arc::new(session_context); + pub fn new( + session_context: SessionContext, + registered_tables: Vec, + ) -> DfSessionService { + let session_context = Arc::new(RwLock::new(session_context)); let parser = Arc::new(Parser { session_context: session_context.clone(), + registered_tables: Arc::new(registered_tables.clone()), // Pass to Parser }); DfSessionService { session_context, parser, + custom_session_vars: Arc::new(RwLock::new(HashMap::new())), + registered_tables: Arc::new(registered_tables), } } -} -#[async_trait] -impl SimpleQueryHandler for DfSessionService { - async fn do_query<'a, C>( - &self, - _client: &mut C, - query: &'a str, - ) -> PgWireResult>> - where - C: ClientInfo + Unpin + Send + Sync, - { - let ctx = &self.session_context; - let df = ctx - .sql(query) + pub async fn register_udfs(&self) -> datafusion::error::Result<()> { + let mut ctx = self.session_context.write().await; + register_current_schemas_udf(&mut ctx)?; + Ok(()) + } + + async fn session_var(&self, key: &str, default: &str) -> String { + self.custom_session_vars + .read() .await + .get(key) + .cloned() + .unwrap_or_else(|| default.to_string()) + } + + async fn handle_set(&self, variable: &ObjectName, value: &[Expr]) -> PgWireResult<()> { + let var_name = variable + .0 + .iter() + .map(|ident| ident.to_string()) + .collect::() + .to_lowercase(); + + let value_str = match value.first() { + Some(Expr::Value(v)) => match &v.value { + sqlparser::ast::Value::SingleQuotedString(s) + | sqlparser::ast::Value::DoubleQuotedString(s) => s.clone(), + sqlparser::ast::Value::Number(n, _) => n.to_string(), + _ => v.to_string(), + }, + Some(expr) => expr.to_string(), + None => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "22023".to_string(), + "SET requires a value".to_string(), + ), + ))); + } + }; + + match var_name.as_str() { + "timezone" => { + let mut sc_guard = self.session_context.write().await; + let mut config = sc_guard.state().config().options().clone(); + config.execution.time_zone = Some(value_str); + let new_context = SessionContext::new_with_config(config.into()); + let old_catalog_names = sc_guard.catalog_names(); + for catalog_name in old_catalog_names { + if let Some(catalog) = sc_guard.catalog(&catalog_name) { + for schema_name in catalog.schema_names() { + if let Some(schema) = catalog.schema(&schema_name) { + for table_name in schema.table_names() { + if let Ok(Some(table)) = schema.table(&table_name).await { + new_context + .register_table(&table_name, table) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + } + } + } + } + } + } + *sc_guard = new_context; + Ok(()) + } + "client_encoding" + | "search_path" + | "application_name" + | "datestyle" + | "client_min_messages" + | "extra_float_digits" + | "standard_conforming_strings" + | "check_function_bodies" + | "transaction_read_only" + | "transaction_isolation" => { + let mut vars = self.custom_session_vars.write().await; + vars.insert(var_name, value_str); + Ok(()) + } + _ => Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42704".to_string(), + format!("Unrecognized configuration parameter '{}'", var_name), + ), + ))), + } + } + + async fn handle_show<'a>(&self, variable: &[Ident]) -> PgWireResult> { + let var_name = variable + .iter() + .map(|ident| ident.to_string()) + .collect::() + .to_lowercase(); + + let sc_guard = self.session_context.read().await; + let config = sc_guard.state().config().options().clone(); + + let value = match var_name.as_str() { + "timezone" | "time" => config + .execution + .time_zone + .clone() + .unwrap_or_else(|| "UTC".to_string()), + "client_encoding" => self.session_var("client_encoding", "UTF8").await, + "search_path" => self.session_var("search_path", "public").await, + "application_name" => self.session_var("application_name", "").await, + "datestyle" => self.session_var("datestyle", "ISO, MDY").await, + "client_min_messages" => self.session_var("client_min_messages", "notice").await, + "extra_float_digits" => self.session_var("extra_float_digits", "3").await, + "standard_conforming_strings" => { + self.session_var("standard_conforming_strings", "on").await + } + "check_function_bodies" => self.session_var("check_function_bodies", "off").await, + "transaction_read_only" => self.session_var("transaction_read_only", "off").await, + "transaction_isolation" => { + self.session_var("transaction_isolation", "read committed") + .await + } + "server_version" => "14.0".to_string(), + "server_version_num" => "140000".to_string(), + "server_encoding" => "UTF8".to_string(), + "is_superuser" => "off".to_string(), + "lc_messages" => "en_US.UTF-8".to_string(), + "lc_monetary" => "en_US.UTF-8".to_string(), + "lc_numeric" => "en_US.UTF-8".to_string(), + "lc_time" => "en_US.UTF-8".to_string(), + "all" => { + let mut names = Vec::new(); + let mut values = Vec::new(); + + if let Some(tz) = &config.execution.time_zone { + names.push("timezone".to_string()); + values.push(tz.clone()); + } + + let custom_vars = self.custom_session_vars.read().await; + for (name, value) in custom_vars.iter() { + names.push(name.clone()); + values.push(value.clone()); + } + + let defaults = vec![ + ("client_encoding", "UTF8"), + ("search_path", "public"), + ("application_name", ""), + ("datestyle", "ISO, MDY"), + ("client_min_messages", "notice"), + ("extra_float_digits", "3"), + ("standard_conforming_strings", "on"), + ("check_function_bodies", "off"), + ("transaction_read_only", "off"), + ("transaction_isolation", "read committed"), + ("server_version", "14.0"), + ("server_version_num", "140000"), + ("server_encoding", "UTF8"), + ("is_superuser", "off"), + ("lc_messages", "en_US.UTF-8"), + ("lc_monetary", "en_US.UTF-8"), + ("lc_numeric", "en_US.UTF-8"), + ("lc_time", "en_US.UTF-8"), + ("time", "UTC"), + ]; + + for (k, v) in defaults { + if !names.contains(&k.to_string()) { + names.push(k.to_string()); + values.push(v.to_string()); + } + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new( + "setting", + DataType::List(Box::new(Field::new("item", DataType::Utf8, true)).into()), + false, + ), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(names)), + Arc::new(StringArray::from(values)), + ], + ) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + let df = sc_guard + .read_batch(batch) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + return datatypes::encode_dataframe(df, &Format::UnifiedText).await; + } + _ => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42704".to_string(), + format!("Unrecognized configuration parameter '{}'", var_name), + ), + ))); + } + }; + + let schema = Arc::new(Schema::new(vec![Field::new( + &var_name, + DataType::Utf8, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(vec![value]))]) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let df = sc_guard + .read_batch(batch) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; - Ok(vec![Response::Query(resp)]) + datatypes::encode_dataframe(df, &Format::UnifiedText).await } } pub struct Parser { - session_context: Arc, + pub session_context: Arc>, + pub registered_tables: Arc>, // Add registered_tables to Parser } #[async_trait] @@ -97,25 +315,192 @@ impl QueryParser for Parser { type Statement = LogicalPlan; async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { - let context = &self.session_context; - let state = context.state(); + let sql_lower = sql.to_lowercase(); + if sql_lower.contains("pg_catalog.pg_class") { + let num_tables = self.registered_tables.len(); + let oid_array = + StringArray::from((1..=num_tables).map(|i| i.to_string()).collect::>()); + let relname_array = StringArray::from(self.registered_tables.as_ref().clone()); + let relnamespace_array = StringArray::from(vec!["public"; num_tables]); + let relkind_array = StringArray::from(vec!["r"; num_tables]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("oid", DataType::Utf8, false), + Field::new("relname", DataType::Utf8, false), + Field::new("relnamespace", DataType::Utf8, false), + Field::new("relkind", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(oid_array), + Arc::new(relname_array), + Arc::new(relnamespace_array), + Arc::new(relkind_array), + ], + ) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + let sc_guard = self.session_context.read().await; + let df = sc_guard + .read_batch(batch) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let logical_plan = df.logical_plan().clone(); + return Ok(logical_plan); + } + let sc_guard = self.session_context.read().await; + let state = sc_guard.state(); let logical_plan = state .create_logical_plan(sql) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let optimised = state + let optimized = state .optimize(&logical_plan) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + Ok(optimized) + } +} + +#[async_trait] +impl SimpleQueryHandler for DfSessionService { + async fn do_query<'a, C>( + &self, + _client: &mut C, + query: &'a str, + ) -> PgWireResult>> + where + C: ClientInfo + Unpin + Send + Sync, + { + let query_trimmed = query.trim(); + let query_lower = query_trimmed.to_lowercase(); + + if query_lower.starts_with("select") && query_lower.contains("current_schemas") { + let string_array = StringArray::from(vec!["public"]); + let field = Field::new("current_schemas", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(string_array)]) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let sc_guard = self.session_context.read().await; + let df = sc_guard + .read_batch(batch) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let encoded = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + return Ok(vec![Response::Query(encoded)]); + } + + if query_lower.contains("pg_catalog.pg_namespace") { + let nspname_array = StringArray::from(vec!["public"]); + let nspissystem_array = StringArray::from(vec!["false"]); + let schema = Arc::new(Schema::new(vec![ + Field::new("nspname", DataType::Utf8, false), + Field::new("nspissystem", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(nspname_array), Arc::new(nspissystem_array)], + ) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let sc_guard = self.session_context.read().await; + let df = sc_guard + .read_batch(batch) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let encoded = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + return Ok(vec![Response::Query(encoded)]); + } + + if query_lower.contains("pg_catalog.pg_class") { + let num_tables = self.registered_tables.len(); + let oid_array = + StringArray::from((1..=num_tables).map(|i| i.to_string()).collect::>()); + let relname_array = StringArray::from(self.registered_tables.as_ref().clone()); + let relnamespace_array = StringArray::from(vec!["public"; num_tables]); + let relkind_array = StringArray::from(vec!["r"; num_tables]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("oid", DataType::Utf8, false), + Field::new("relname", DataType::Utf8, false), + Field::new("relnamespace", DataType::Utf8, false), + Field::new("relkind", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(oid_array), + Arc::new(relname_array), + Arc::new(relnamespace_array), + Arc::new(relkind_array), + ], + ) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let sc_guard = self.session_context.read().await; + let df = sc_guard + .read_batch(batch) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let encoded = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + return Ok(vec![Response::Query(encoded)]); + } + + if query_lower.starts_with("set time zone") { + let parts: Vec<&str> = query_trimmed.split_whitespace().collect(); + if parts.len() >= 4 { + let tz = parts[3].trim_matches('\'').trim_matches('"'); + let object_name = + ObjectName(vec![ObjectNamePart::Identifier(Ident::new("timezone"))]); + let expr = + Expr::Value(sqlparser::ast::Value::SingleQuotedString(tz.to_string()).into()); + self.handle_set(&object_name, &[expr]).await?; + return Ok(vec![Response::Execution(pgwire::api::results::Tag::new( + "SET", + ))]); + } + } + + let dialect = GenericDialect {}; + let stmts = SqlParser::parse_sql(&dialect, query) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let mut responses = Vec::with_capacity(stmts.len()); - Ok(optimised) + for statement in stmts { + let stmt_string = statement.to_string().trim().to_owned(); + if stmt_string.is_empty() { + continue; + } + match statement { + Statement::SetVariable { + variables, value, .. + } => { + let var = match variables { + sqlparser::ast::OneOrManyWithParens::One(ref name) => name, + sqlparser::ast::OneOrManyWithParens::Many(ref names) => { + names.first().unwrap() + } + }; + self.handle_set(var, &value).await?; + responses.push(Response::Execution(pgwire::api::results::Tag::new("SET"))); + } + Statement::ShowVariable { variable } => { + let resp = self.handle_show(&variable).await?; + responses.push(Response::Query(resp)); + } + _ => { + let sc_guard = self.session_context.read().await; + let df = sc_guard + .sql(&stmt_string) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + responses.push(Response::Query(resp)); + } + } + } + Ok(responses) } } #[async_trait] impl ExtendedQueryHandler for DfSessionService { type Statement = LogicalPlan; - type QueryParser = Parser; fn query_parser(&self) -> Arc { @@ -131,23 +516,28 @@ impl ExtendedQueryHandler for DfSessionService { C: ClientInfo + Unpin + Send + Sync, { let plan = &target.statement; - let schema = plan.schema(); let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?; let params = plan .get_parameter_types() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let mut param_types = Vec::with_capacity(params.len()); + for param_type in ordered_param_types(¶ms).iter() { if let Some(datatype) = param_type { - let pgtype = into_pg_type(datatype)?; - param_types.push(pgtype); + match datatype { + DataType::List(inner) if matches!(inner.data_type(), DataType::Utf8) => { + param_types.push(Type::TEXT_ARRAY); + } + _ => { + let pgtype = into_pg_type(datatype)?; + param_types.push(pgtype); + } + } } else { - param_types.push(Type::UNKNOWN); + param_types.push(Type::TEXT_ARRAY); // Default for pgcli's relkind array } } - Ok(DescribeStatementResponse::new(param_types, fields)) } @@ -163,7 +553,6 @@ impl ExtendedQueryHandler for DfSessionService { let format = &target.result_column_format; let schema = plan.schema(); let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?; - Ok(DescribePortalResponse::new(fields)) } @@ -176,22 +565,47 @@ impl ExtendedQueryHandler for DfSessionService { where C: ClientInfo + Unpin + Send + Sync, { - let plan = &portal.statement.statement; + let stmt_string = portal.statement.id.clone(); + let stmt_upper = stmt_string.to_uppercase(); + + if stmt_upper.starts_with("SET ") { + let dialect = GenericDialect {}; + let stmts = SqlParser::parse_sql(&dialect, &stmt_string) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + if let Statement::SetVariable { + variables, value, .. + } = &stmts[0] + { + let var = match variables { + sqlparser::ast::OneOrManyWithParens::One(ref name) => name, + sqlparser::ast::OneOrManyWithParens::Many(ref names) => names.first().unwrap(), + }; + self.handle_set(var, value).await?; + return Ok(Response::Execution(pgwire::api::results::Tag::new("SET"))); + } + } else if stmt_upper.starts_with("SHOW ") { + let dialect = GenericDialect {}; + let stmts = SqlParser::parse_sql(&dialect, &stmt_string) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + if let Statement::ShowVariable { variable } = &stmts[0] { + let resp = self.handle_show(variable).await?; + return Ok(Response::Query(resp)); + } + } + let plan = &portal.statement.statement; let param_types = plan .get_parameter_types() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let param_values = datatypes::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; - let plan = plan .clone() .replace_params_with_values(¶m_values) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let dataframe = self - .session_context + let sc_guard = self.session_context.read().await; + let dataframe = sc_guard .execute_logical_plan(plan) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -202,10 +616,33 @@ impl ExtendedQueryHandler for DfSessionService { } fn ordered_param_types(types: &HashMap>) -> Vec> { - // Datafusion stores the parameters as a map. In our case, the keys will be - // `$1`, `$2` etc. The values will be the parameter types. + // Datafusion stores the parameters as a map. In our case, the keys will be + // `$1`, `$2` etc. The values will be the parameter types. + let mut types_vec = types.iter().collect::>(); + types_vec.sort_by(|a, b| a.0.cmp(b.0)); + types_vec.into_iter().map(|pt| pt.1.as_ref()).collect() +} - let mut types = types.iter().collect::>(); - types.sort_by(|a, b| a.0.cmp(b.0)); - types.into_iter().map(|pt| pt.1.as_ref()).collect() +fn register_current_schemas_udf(ctx: &mut SessionContext) -> datafusion::error::Result<()> { + let current_schemas_fn = Arc::new( + move |args: &[ColumnarValue]| -> datafusion::error::Result { + let num_rows = match &args[0] { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }; + let string_array = StringArray::from(vec!["public"; num_rows]); + Ok(ColumnarValue::Array(Arc::new(string_array))) + }, + ); + + let udf = create_udf( + "current_schemas", + vec![DataType::Boolean], + DataType::Utf8, + Volatility::Immutable, + current_schemas_fn, + ); + + ctx.register_udf(udf); + Ok(()) }