From 07f2193363e104f1f654743387da4e7ad3de9286 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 27 Aug 2025 20:01:18 +0800 Subject: [PATCH 01/13] feat: implement dbeaver startup queries one-by-one --- Cargo.lock | 1 + datafusion-postgres/Cargo.toml | 3 +++ datafusion-postgres/src/handlers.rs | 15 +++++++-------- datafusion-postgres/src/pg_catalog.rs | 21 +++++++++++++++++++++ datafusion-postgres/tests/dbeaver.rs | 26 +++++++++++++++++++++++--- 5 files changed, 55 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 33bc38a..d311afc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1494,6 +1494,7 @@ dependencies = [ "bytes", "chrono", "datafusion", + "env_logger", "futures", "getset", "log", diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index c87f968..4889673 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -28,3 +28,6 @@ tokio = { version = "1.47", features = ["sync", "net"] } tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } rustls-pemfile = "2.0" rustls-pki-types = "1.0" + +[dev-dependencies] +env_logger = "0.11" diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 66e6ebe..fe70ce9 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -6,6 +6,7 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; +use log::warn; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::auth::StartupHandler; use pgwire::api::portal::{Format, Portal}; @@ -198,14 +199,12 @@ impl DfSessionService { } } else { // pass SET query to datafusion - let df = self - .session_context - .sql(query_lower) - .await - .map_err(|err| PgWireError::ApiError(Box::new(err)))?; - - let resp = df::encode_dataframe(df, &Format::UnifiedText).await?; - Ok(Some(Response::Query(resp))) + if let Err(e) = self.session_context.sql(query_lower).await { + warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored"); + } + + // Always return SET success + Ok(Some(Response::Execution(Tag::new("SET")))) } } else { Ok(None) diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 90e3a31..426fcac 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -869,6 +869,26 @@ pub fn create_format_type_udf() -> ScalarUDF { ) } +pub fn create_session_user_udf() -> ScalarUDF { + let func = move |_args: &[ColumnarValue]| { + let mut builder = StringBuilder::new(); + // TODO: return real user + builder.append_value("postgres"); + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "session_user", + vec![], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + /// Install pg_catalog and postgres UDFs to current `SessionContext` pub fn setup_pg_catalog( session_context: &SessionContext, @@ -892,6 +912,7 @@ pub fn setup_pg_catalog( session_context.register_udf(create_has_table_privilege_2param_udf()); session_context.register_udf(create_pg_table_is_visible()); session_context.register_udf(create_format_type_udf()); + session_context.register_udf(create_session_user_udf()); Ok(()) } diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index b1776d7..f6ca62a 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ 
b/datafusion-postgres/tests/dbeaver.rs @@ -3,12 +3,32 @@ mod common; use common::*; use pgwire::api::query::SimpleQueryHandler; +const DBEAVER_QUERIES: &[&str] = &[ + "SET extra_float_digits = 3", + "SET application_name = 'PostgreSQL JDBC Driver'", + "SET application_name = 'DBeaver 25.1.5 - Main '", + "SELECT current_schema(),session_user", + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname", + "SELECT n.nspsname = ANY(current_schemas(true)), n.nspsname, t.typname FROM pg_catalog.pg_type t JOIN pg_catalog.pg_namespace n ON t.typrelid = n.oid WHERE pg_type.oid = 1034", + "SHOW search_path", + "SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'", + "SELECT * FROM pg_catalog.pg_settinngs where name='standard_conforming_strings'", + "SELECT string_agg(word, ',' ) from pg_catalog.pg_get_keywords() where word <> ALL ('{a,abs,absolute,action,ada,add,admin,after,all,allocate,alter,aIways,and,any,are,array,as,asc,asenstitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,before,begin,bernoulli,between,bigint,binary,blob,boolean,both,breaadth,by,c,call,called,cardinaliity,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,char_length,character,character_length,character_set_catalog,character_set_name,character_set_schema,characteristics,characters,check,checkeed,class_origin,clob,close,coalesce,coboI,code_units,collate,collation,collaition_catalog,collaition_name,collaition_schema,collect,colum,column_name,command_function,command_function_code,commit,committed,condiition,condiition_number,connect,connection_name,constraint,constraint_catalog,constraint_name,constraint_schema,constraints,constructors,contains,continue,convert,corr,correspondiing,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_collation,current_date,current_default_transfom_group,current_path,current_role,current_time,current_timestamp,current_transfom_group_for_type,current_user,cursor,cursor_name,cycle,data,date,datetime_interval_code,datetime_interval_precision,day,deallocate,dec,decimaI,declare,default,defaults,not,null,nullable,nullif,nulls,number,numeric,object,octeet_length,octets,of,old,on,only,open,option,options,or,order,ordering,ordinaliity,others,out,outer,output,over,overlaps,overlay,overriding,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_speciific_catalog,parameter_speciific_name,parameter_speciific_schema,partiaI,partitioon,pascal,path,percent_rank,percentile_cont,percentile_disc,placing,pli,position,power,preceding,precision,prepare,preseerv,primary,prior,privileges,procedure,public,range,rank,read,reads,real,recursivve,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_sy 
y,relative,release,repeatable,restart,result,retun,returned_cardinality,returned_length,returned_octeet_length,returned_sqlstate,returns,revoe,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,row_count,row_number,rows,savepoint,scale,schema,schema_name,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,self,sensitive,sequence,seriializeable,server_name,session,session_user,set,sets,similar,simple,size,smalIint,some,source,space,specifiic,speciific_name,speciifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,state,statement,static,stddev_pop,stddev_samp,structure,style,subclass_origin,submultiset,substring,sum,symmetric,system,system_user,table,table_name,tablesample,temporary,then,ties,time,timesamp,timezone_hour,timezone_minute,to,top_level_count,trailing,transaction,transaction_active,transactions_committed,transactions_rolled_back,transfor,transforms,translate,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,true,type,unbounde,undefined,uncommitted,under,union,unique,unknown,unnaamed,unnest,update,upper,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,value,values,var_pop,var_samp,varchar,varying,view,when,whenever,where,width_bucket,window,with,within,without,work,write,year,zone", + "SELECT version()", + "SELECT * FROM pg_catalog.pg_enum WHERE 1<>1 LIMIT 1", + "SELECT reltype FROM pg_catalog.pg_class WHERE 1<>1 LIMIT 1", + "SELECT t.oid,t.*,c.relkind,format_type(nullif(t.typbasetype, 0), t.typtypmod) as base_type_name, d.description FROM pg_catalog.pg_type t LEFT OUTER JOIN pg_catalog.pg_type et ON et.oid=t.typelem LEFT OUTER JOIN pg_catalog.pg_class c ON c.oid=t.typrelid LEFT OUTER JOIN pg_catalog.pg_description d ON t.oid=d.objoid WHERE t.typname IS NOT NULL AND (c.relkind IS NULL OR c.relkind = 'c') AND (et.typcategory IS NULL OR et.typcategory <> 'C')", +]; + #[tokio::test] pub async fn test_dbeaver_startup_sql() { + env_logger::init(); let service = setup_handlers(); let mut client = MockClient::new(); - SimpleQueryHandler::do_query(&service, &mut client, "SELECT 1") - .await - .expect("failed to run sql"); + for query in DBEAVER_QUERIES { + SimpleQueryHandler::do_query(&service, &mut client, query) + .await + .expect(&format!("failed to run sql: {query}")); + } } From 975fef0be775473684bcad12fd162bb3664a1d49 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 28 Aug 2025 17:40:57 +0800 Subject: [PATCH 02/13] fix: correct query feed into handler --- datafusion-postgres/src/handlers.rs | 7 ++++--- datafusion-postgres/src/sql.rs | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 56e1997..02083f7 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -324,7 +324,8 @@ impl SimpleQueryHandler for DfSessionService { statement = rewrite(statement, &self.sql_rewrite_rules); // TODO: improve statement check by using statement directly - let query_lower = statement.to_string().to_lowercase().trim().to_string(); + let query = statement.to_string(); + let query_lower = query.to_lowercase().trim().to_string(); // Check permissions for the query (skip for SET, transaction, and SHOW statements) if !query_lower.starts_with("set") @@ -336,7 +337,7 @@ impl SimpleQueryHandler for DfSessionService { && !query_lower.starts_with("abort") && !query_lower.starts_with("show") { - 
self.check_query_permission(client, query).await?; + self.check_query_permission(client, &query).await?; } if let Some(resp) = self.try_respond_set_statements(&query_lower).await? { @@ -366,7 +367,7 @@ impl SimpleQueryHandler for DfSessionService { ))); } - let df_result = self.session_context.sql(query).await; + let df_result = self.session_context.sql(&query).await; // Handle query execution errors and transaction state let df = match df_result { diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 9002811..8b75b91 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -170,5 +170,14 @@ mod tests { statement.to_string(), "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id" ); + + let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname"; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, &rules); + assert_eq!( + statement.to_string(), + "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname" + ); } } From 95f048763cdb439132ad59852718bfa57265b51c Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sat, 30 Aug 2025 22:29:46 +0800 Subject: [PATCH 03/13] chore: correct test case --- datafusion-postgres/tests/dbeaver.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index f6ca62a..44543b3 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -8,7 +8,7 @@ const DBEAVER_QUERIES: &[&str] = &[ "SET application_name = 'PostgreSQL JDBC Driver'", "SET application_name = 'DBeaver 25.1.5 - Main '", "SELECT current_schema(),session_user", - "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname", + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname", "SELECT n.nspsname = ANY(current_schemas(true)), n.nspsname, t.typname FROM pg_catalog.pg_type t JOIN pg_catalog.pg_namespace n ON t.typrelid = n.oid WHERE pg_type.oid = 1034", "SHOW search_path", "SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'", From 3cb9adce67bf1904f1f5d11fd7931043abf62b81 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 31 Aug 2025 18:25:19 +0800 Subject: [PATCH 04/13] test: update some test case --- datafusion-postgres/src/sql.rs | 1 + datafusion-postgres/tests/dbeaver.rs | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 12c3f1a..703407f 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -272,6 +272,7 @@ impl RemoveUnsupportedTypes { pub fn new() -> Self { let mut unsupported_types = HashSet::new(); unsupported_types.insert("regclass".to_owned()); + unsupported_types.insert("regproc".to_owned()); Self { unsupported_types } } diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index 44543b3..0793fe9 100644 --- 
a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -9,7 +9,8 @@ const DBEAVER_QUERIES: &[&str] = &[ "SET application_name = 'DBeaver 25.1.5 - Main '", "SELECT current_schema(),session_user", "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname", - "SELECT n.nspsname = ANY(current_schemas(true)), n.nspsname, t.typname FROM pg_catalog.pg_type t JOIN pg_catalog.pg_namespace n ON t.typrelid = n.oid WHERE pg_type.oid = 1034", + "SELECT n.nspname = ANY(current_schemas(true)), n.nspname, t.typname FROM pg_catalog.pg_type t JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid WHERE t.oid = 1034", + "SELECT typinput='pg_catalog.array_in'::regproc as is_array, typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN (select ns.oid as nspoid, ns.nspname, r.r from pg_namespace as ns join ( select s.r, (current_schemas(false))[s.r] as nspname from generate_series(1, array_upper(current_schemas(false), 1)) as s(r) ) as r using ( nspname ) ) as sp ON sp.nspoid = typnamespace WHERE pg_type.oid = 1034 ORDER BY sp.r, pg_type.oid DESC", "SHOW search_path", "SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'", "SELECT * FROM pg_catalog.pg_settinngs where name='standard_conforming_strings'", From 08bdc60e0a16848eae0cb9bbda31ba6e2427577c Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 31 Aug 2025 18:33:47 +0800 Subject: [PATCH 05/13] feat: add rewrite rule to transform any operation --- datafusion-postgres/src/handlers.rs | 3 +- datafusion-postgres/src/sql.rs | 97 +++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 00ae49d..028e258 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes, - ResolveUnqualifiedIdentifer, SqlStatementRewriteRule, + ResolveUnqualifiedIdentifer, RewriteArrayAnyOperation, SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; @@ -81,6 +81,7 @@ impl DfSessionService { Arc::new(AliasDuplicatedProjectionRewrite), Arc::new(ResolveUnqualifiedIdentifer), Arc::new(RemoveUnsupportedTypes::new()), + Arc::new(RewriteArrayAnyOperation), ]; let parser = Arc::new(Parser { session_context: session_context.clone(), diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 703407f..596c4a1 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -2,8 +2,15 @@ use std::collections::HashSet; use std::ops::ControlFlow; use std::sync::Arc; +use datafusion::sql::sqlparser::ast::BinaryOperator; use datafusion::sql::sqlparser::ast::Expr; +use datafusion::sql::sqlparser::ast::Function; +use datafusion::sql::sqlparser::ast::FunctionArg; +use datafusion::sql::sqlparser::ast::FunctionArgExpr; +use datafusion::sql::sqlparser::ast::FunctionArgumentList; +use datafusion::sql::sqlparser::ast::FunctionArguments; use datafusion::sql::sqlparser::ast::Ident; +use datafusion::sql::sqlparser::ast::ObjectName; use datafusion::sql::sqlparser::ast::OrderByKind; use datafusion::sql::sqlparser::ast::Query; use datafusion::sql::sqlparser::ast::Select; @@ -13,6 +20,7 @@ use 
datafusion::sql::sqlparser::ast::SetExpr; use datafusion::sql::sqlparser::ast::Statement; use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::ast::TableWithJoins; +use datafusion::sql::sqlparser::ast::UnaryOperator; use datafusion::sql::sqlparser::ast::Value; use datafusion::sql::sqlparser::ast::VisitMut; use datafusion::sql::sqlparser::ast::VisitorMut; @@ -327,6 +335,72 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes { } } +/// Rewrite Postgres's ANY operator to array_contains +#[derive(Debug)] +pub struct RewriteArrayAnyOperation; + +struct RewriteArrayAnyOperationVisitor; + +impl RewriteArrayAnyOperationVisitor { + fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr { + Expr::Function(Function { + name: ObjectName::from(vec![Ident::new("array_contains")]), + args: FunctionArguments::List(FunctionArgumentList { + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(right.clone())), + FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())), + ], + duplicate_treatment: None, + clauses: vec![], + }), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }) + } +} + +impl VisitorMut for RewriteArrayAnyOperationVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if let Expr::AnyOp { + left, + compare_op, + right, + .. + } = expr + { + match compare_op { + BinaryOperator::Eq => { + *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref()); + } + BinaryOperator::NotEq => { + *expr = Expr::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())), + } + } + _ => {} + } + } + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RewriteArrayAnyOperation { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RewriteArrayAnyOperationVisitor; + + let _ = s.visit(&mut visitor); + + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -427,4 +501,27 @@ mod tests { "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" ); } + + #[test] + fn test_any_to_array_contains() { + let rules: Vec> = vec![Arc::new(RewriteArrayAnyOperation)]; + + assert_rewrite!( + &rules, + "SELECT a = ANY(current_schemas(true))", + "SELECT array_contains(current_schemas(true), a)" + ); + + assert_rewrite!( + &rules, + "SELECT a != ANY(current_schemas(true))", + "SELECT NOT array_contains(current_schemas(true), a)" + ); + + assert_rewrite!( + &rules, + "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))", + "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)" + ); + } } From 678c74687a3fd1d59366268a69df926968744ad2 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 31 Aug 2025 22:41:41 +0800 Subject: [PATCH 06/13] feat: add more rewrite rules --- datafusion-postgres/src/handlers.rs | 10 ++-- datafusion-postgres/src/sql.rs | 82 ++++++++++++++++++++++++++++ datafusion-postgres/tests/dbeaver.rs | 2 +- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 028e258..a2984e4 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -3,8 +3,9 @@ use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ - parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes, - ResolveUnqualifiedIdentifer, 
RewriteArrayAnyOperation, SqlStatementRewriteRule, + parse, rewrite, AliasDuplicatedProjectionRewrite, PrependUnqualifiedTableName, + RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, RewriteArrayAnyOperation, + SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; @@ -82,6 +83,7 @@ impl DfSessionService { Arc::new(ResolveUnqualifiedIdentifer), Arc::new(RemoveUnsupportedTypes::new()), Arc::new(RewriteArrayAnyOperation), + Arc::new(PrependUnqualifiedTableName::new()), ]; let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -297,8 +299,8 @@ impl DfSessionService { Ok(Some(Response::Query(resp))) } "show search_path" => { - let default_catalog = "datafusion"; - let resp = Self::mock_show_response("search_path", default_catalog)?; + let default_schema = "public"; + let resp = Self::mock_show_response("search_path", default_schema)?; Ok(Some(Response::Query(resp))) } _ => Err(PgWireError::UserError(Box::new( diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 596c4a1..457ce21 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -11,6 +11,7 @@ use datafusion::sql::sqlparser::ast::FunctionArgumentList; use datafusion::sql::sqlparser::ast::FunctionArguments; use datafusion::sql::sqlparser::ast::Ident; use datafusion::sql::sqlparser::ast::ObjectName; +use datafusion::sql::sqlparser::ast::ObjectNamePart; use datafusion::sql::sqlparser::ast::OrderByKind; use datafusion::sql::sqlparser::ast::Query; use datafusion::sql::sqlparser::ast::Select; @@ -401,6 +402,63 @@ impl SqlStatementRewriteRule for RewriteArrayAnyOperation { } } +/// Prepend qualifier to table_name +/// +/// Postgres has pg_catalog in search_path by default so it allow access to +/// `pg_namespace` without `pg_catalog.` qualifier +#[derive(Debug)] +pub struct PrependUnqualifiedTableName { + table_names: HashSet, +} + +impl PrependUnqualifiedTableName { + pub fn new() -> Self { + let mut table_names = HashSet::new(); + + table_names.insert("pg_namespace".to_owned()); + + Self { table_names } + } +} + +struct PrependUnqualifiedTableNameVisitor<'a> { + table_names: &'a HashSet, +} + +impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> { + type Break = (); + + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + if let TableFactor::Table { name, .. 
} = table_factor { + if name.0.len() == 1 { + let ObjectNamePart::Identifier(ident) = &name.0[0]; + if self.table_names.contains(&ident.to_string()) { + *name = ObjectName(vec![ + ObjectNamePart::Identifier(Ident::new("pg_catalog")), + name.0[0].clone(), + ]); + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for PrependUnqualifiedTableName { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = PrependUnqualifiedTableNameVisitor { + table_names: &self.table_names, + }; + + let _ = s.visit(&mut visitor); + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -524,4 +582,28 @@ mod tests { "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)" ); } + + #[test] + fn test_prepend_unqualified_table_name() { + let rules: Vec> = + vec![Arc::new(PrependUnqualifiedTableName::new())]; + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_namespace", + "SELECT * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT * FROM pg_namespace", + "SELECT * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid", + "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid" + ); + } } diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index 0793fe9..95edfb6 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -13,7 +13,7 @@ const DBEAVER_QUERIES: &[&str] = &[ "SELECT typinput='pg_catalog.array_in'::regproc as is_array, typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN (select ns.oid as nspoid, ns.nspname, r.r from pg_namespace as ns join ( select s.r, (current_schemas(false))[s.r] as nspname from generate_series(1, array_upper(current_schemas(false), 1)) as s(r) ) as r using ( nspname ) ) as sp ON sp.nspoid = typnamespace WHERE pg_type.oid = 1034 ORDER BY sp.r, pg_type.oid DESC", "SHOW search_path", "SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'", - "SELECT * FROM pg_catalog.pg_settinngs where name='standard_conforming_strings'", + "SELECT * FROM pg_catalog.pg_settings where name='standard_conforming_strings'", "SELECT string_agg(word, ',' ) from pg_catalog.pg_get_keywords() where word <> ALL 
('{a,abs,absolute,action,ada,add,admin,after,all,allocate,alter,aIways,and,any,are,array,as,asc,asenstitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,before,begin,bernoulli,between,bigint,binary,blob,boolean,both,breaadth,by,c,call,called,cardinaliity,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,char_length,character,character_length,character_set_catalog,character_set_name,character_set_schema,characteristics,characters,check,checkeed,class_origin,clob,close,coalesce,coboI,code_units,collate,collation,collaition_catalog,collaition_name,collaition_schema,collect,colum,column_name,command_function,command_function_code,commit,committed,condiition,condiition_number,connect,connection_name,constraint,constraint_catalog,constraint_name,constraint_schema,constraints,constructors,contains,continue,convert,corr,correspondiing,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_collation,current_date,current_default_transfom_group,current_path,current_role,current_time,current_timestamp,current_transfom_group_for_type,current_user,cursor,cursor_name,cycle,data,date,datetime_interval_code,datetime_interval_precision,day,deallocate,dec,decimaI,declare,default,defaults,not,null,nullable,nullif,nulls,number,numeric,object,octeet_length,octets,of,old,on,only,open,option,options,or,order,ordering,ordinaliity,others,out,outer,output,over,overlaps,overlay,overriding,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_speciific_catalog,parameter_speciific_name,parameter_speciific_schema,partiaI,partitioon,pascal,path,percent_rank,percentile_cont,percentile_disc,placing,pli,position,power,preceding,precision,prepare,preseerv,primary,prior,privileges,procedure,public,range,rank,read,reads,real,recursivve,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_sy y,relative,release,repeatable,restart,result,retun,returned_cardinality,returned_length,returned_octeet_length,returned_sqlstate,returns,revoe,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,row_count,row_number,rows,savepoint,scale,schema,schema_name,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,self,sensitive,sequence,seriializeable,server_name,session,session_user,set,sets,similar,simple,size,smalIint,some,source,space,specifiic,speciific_name,speciifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,state,statement,static,stddev_pop,stddev_samp,structure,style,subclass_origin,submultiset,substring,sum,symmetric,system,system_user,table,table_name,tablesample,temporary,then,ties,time,timesamp,timezone_hour,timezone_minute,to,top_level_count,trailing,transaction,transaction_active,transactions_committed,transactions_rolled_back,transfor,transforms,translate,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,true,type,unbounde,undefined,uncommitted,under,union,unique,unknown,unnaamed,unnest,update,upper,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,value,values,var_pop,var_samp,varchar,varying,view,when,whenever,where,width_bucket,window,with,within,without,work,write,year,zone", "SELECT version()", "SELECT * FROM pg_catalog.pg_enum WHERE 1<>1 LIMIT 1", From 0af3d80ce33c184580dae6b612c7a4aea2eb23d9 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 1 Sep 2025 16:26:26 +0800 Subject: [PATCH 07/13] feat: add 
pg_settings view --- datafusion-postgres/src/pg_catalog.rs | 7 ++ .../src/pg_catalog/pg_settings.rs | 115 ++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 datafusion-postgres/src/pg_catalog/pg_settings.rs diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 426fcac..35573ad 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -24,6 +24,7 @@ mod pg_attribute; mod pg_class; mod pg_database; mod pg_namespace; +mod pg_settings; const PG_CATALOG_TABLE_PG_AGGREGATE: &str = "pg_aggregate"; const PG_CATALOG_TABLE_PG_AM: &str = "pg_am"; @@ -86,6 +87,7 @@ const PG_CATALOG_TABLE_PG_SUBSCRIPTION_REL: &str = "pg_subscription_rel"; const PG_CATALOG_TABLE_PG_TABLESPACE: &str = "pg_tablespace"; const PG_CATALOG_TABLE_PG_TRIGGER: &str = "pg_trigger"; const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = "pg_user_mapping"; +const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings"; /// Determine PostgreSQL table type (relkind) from DataFusion TableProvider fn get_table_type(table: &Arc) -> &'static str { @@ -180,6 +182,7 @@ pub const PG_CATALOG_TABLES: &[&str] = &[ PG_CATALOG_TABLE_PG_TABLESPACE, PG_CATALOG_TABLE_PG_TRIGGER, PG_CATALOG_TABLE_PG_USER_MAPPING, + PG_CATALOG_VIEW_PG_SETTINGS, ]; #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)] @@ -345,6 +348,10 @@ impl SchemaProvider for PgCatalogSchemaProvider { StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(), ))) } + PG_CATALOG_VIEW_PG_SETTINGS => { + let table = pg_settings::PgSettingsView::try_new()?; + Ok(Some(Arc::new(table.try_into_memtable()?))) + } _ => Ok(None), } diff --git a/datafusion-postgres/src/pg_catalog/pg_settings.rs b/datafusion-postgres/src/pg_catalog/pg_settings.rs new file mode 100644 index 0000000..c94cd82 --- /dev/null +++ b/datafusion-postgres/src/pg_catalog/pg_settings.rs @@ -0,0 +1,115 @@ +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::catalog::MemTable; +use datafusion::error::Result; + +#[derive(Debug, Clone)] +pub(crate) struct PgSettingsView { + schema: SchemaRef, + data: Vec, +} + +impl PgSettingsView { + pub(crate) fn try_new() -> Result { + let schema = Arc::new(Schema::new(vec![ + // name | setting | unit | category | short_ + //desc | extra_desc + //| context | vartype | source | min_val | max_val | enumvals | + //boot_val | reset_val | sourcefile | sourceline | pending_restart + Field::new("name", DataType::Utf8, true), + Field::new("setting", DataType::Utf8, true), + Field::new("unit", DataType::Utf8, true), + Field::new("category", DataType::Utf8, true), + Field::new("short_desc", DataType::Utf8, true), + Field::new("extra_desc", DataType::Utf8, true), + Field::new("context", DataType::Utf8, true), + Field::new("vartype", DataType::Utf8, true), + Field::new("source", DataType::Utf8, true), + Field::new("min_val", DataType::Utf8, true), + Field::new("max_val", DataType::Utf8, true), + Field::new("enumvals", DataType::Utf8, true), + Field::new("bool_val", DataType::Utf8, true), + Field::new("reset_val", DataType::Utf8, true), + Field::new("sourcefile", DataType::Utf8, true), + Field::new("sourceline", DataType::Int32, true), + Field::new("pending_restart", DataType::Boolean, true), + ])); + + let data = Self::create_data(schema.clone())?; + + Ok(Self { schema, data }) + } + + fn create_data(schema: Arc) -> Result> { + let mut name: 
Vec> = Vec::new(); + let mut setting: Vec> = Vec::new(); + let mut unit: Vec> = Vec::new(); + let mut category: Vec> = Vec::new(); + let mut short_desc: Vec> = Vec::new(); + let mut extra_desc: Vec> = Vec::new(); + let mut context: Vec> = Vec::new(); + let mut vartype: Vec> = Vec::new(); + let mut source: Vec> = Vec::new(); + let mut min_val: Vec> = Vec::new(); + let mut max_val: Vec> = Vec::new(); + let mut enumvals: Vec> = Vec::new(); + let mut bool_val: Vec> = Vec::new(); + let mut reset_val: Vec> = Vec::new(); + let mut sourcefile: Vec> = Vec::new(); + let mut sourceline: Vec> = Vec::new(); + let mut pending_restart: Vec> = Vec::new(); + + let data = vec![("standard_conforming_strings", "on")]; + + for (setting_name, setting_val) in data { + name.push(Some(setting_name)); + setting.push(Some(setting_val)); + + unit.push(None); + category.push(None); + short_desc.push(None); + extra_desc.push(None); + context.push(None); + vartype.push(None); + source.push(None); + min_val.push(None); + max_val.push(None); + enumvals.push(None); + bool_val.push(None); + reset_val.push(None); + sourcefile.push(None); + sourceline.push(None); + pending_restart.push(None); + } + + let arrays: Vec = vec![ + Arc::new(StringArray::from(name)), + Arc::new(StringArray::from(setting)), + Arc::new(StringArray::from(unit)), + Arc::new(StringArray::from(category)), + Arc::new(StringArray::from(short_desc)), + Arc::new(StringArray::from(extra_desc)), + Arc::new(StringArray::from(context)), + Arc::new(StringArray::from(vartype)), + Arc::new(StringArray::from(source)), + Arc::new(StringArray::from(min_val)), + Arc::new(StringArray::from(max_val)), + Arc::new(StringArray::from(enumvals)), + Arc::new(StringArray::from(bool_val)), + Arc::new(StringArray::from(reset_val)), + Arc::new(StringArray::from(sourcefile)), + Arc::new(Int32Array::from(sourceline)), + Arc::new(BooleanArray::from(pending_restart)), + ]; + + let batch = RecordBatch::try_new(schema.clone(), arrays)?; + + Ok(vec![batch]) + } + + pub(crate) fn try_into_memtable(self) -> Result { + MemTable::try_new(self.schema, vec![self.data]) + } +} From 8859e626b29012d493b501007542c3d301918a21 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 1 Sep 2025 18:39:16 +0800 Subject: [PATCH 08/13] fix: crash on encoding utf8view list --- arrow-pg/src/list_encoder.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index d1ca983..49f6373 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -1,5 +1,6 @@ use std::{str::FromStr, sync::Arc}; +use arrow::array::{BinaryViewArray, StringViewArray}; #[cfg(not(feature = "datafusion"))] use arrow::{ array::{ @@ -150,6 +151,15 @@ pub(crate) fn encode_list( .collect(); encode_field(&value, type_, format) } + DataType::Utf8View => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } DataType::Binary => { let value: Vec> = arr .as_any() @@ -168,6 +178,15 @@ pub(crate) fn encode_list( .collect(); encode_field(&value, type_, format) } + DataType::BinaryView => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } DataType::Date32 => { let value: Vec> = arr From 7c9777b7b708e72521e8d7657c809f84786b762f Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Tue, 2 Sep 2025 22:28:35 +0800 Subject: [PATCH 09/13] feat: add array rewrite --- datafusion-postgres/src/sql.rs | 109 
+++++++++++++++++++++++++++ datafusion-postgres/tests/dbeaver.rs | 2 +- 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 457ce21..1b43ea0 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -2,7 +2,11 @@ use std::collections::HashSet; use std::ops::ControlFlow; use std::sync::Arc; +use datafusion::sql::sqlparser::ast::Array; +use datafusion::sql::sqlparser::ast::ArrayElemTypeDef; use datafusion::sql::sqlparser::ast::BinaryOperator; +use datafusion::sql::sqlparser::ast::CastKind; +use datafusion::sql::sqlparser::ast::DataType; use datafusion::sql::sqlparser::ast::Expr; use datafusion::sql::sqlparser::ast::Function; use datafusion::sql::sqlparser::ast::FunctionArg; @@ -23,6 +27,7 @@ use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::ast::TableWithJoins; use datafusion::sql::sqlparser::ast::UnaryOperator; use datafusion::sql::sqlparser::ast::Value; +use datafusion::sql::sqlparser::ast::ValueWithSpan; use datafusion::sql::sqlparser::ast::VisitMut; use datafusion::sql::sqlparser::ast::VisitorMut; use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; @@ -459,6 +464,87 @@ impl SqlStatementRewriteRule for PrependUnqualifiedTableName { } } +#[derive(Debug)] +pub struct FixArrayLiteral; + +struct FixArrayLiteralVisitor; + +impl FixArrayLiteralVisitor { + fn is_string_type(dt: &DataType) -> bool { + matches!( + dt, + DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_) + ) + } +} + +impl VisitorMut for FixArrayLiteralVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if let Expr::Cast { + kind, + expr, + data_type, + .. + } = expr + { + if kind == &CastKind::DoubleColon { + if let DataType::Array(arr) = data_type { + // cast some to + if let Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(array_literal), + .. 
+ }) = expr.as_ref() + { + let items = + array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' '); + let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()); + + let is_text = match arr { + ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()), + ArrayElemTypeDef::SquareBracket(dt, _) => { + Self::is_string_type(dt.as_ref()) + } + ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()), + _ => false, + }; + + let elems = items + .map(|s| { + if is_text { + Expr::Value( + Value::SingleQuotedString(s.to_string()).with_empty_span(), + ) + } else { + Expr::Value( + Value::Number(s.to_string(), false).with_empty_span(), + ) + } + }) + .collect(); + *expr = Box::new(Expr::Array(Array { + elem: elems, + named: true, + })); + } + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for FixArrayLiteral { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = FixArrayLiteralVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -606,4 +692,27 @@ mod tests { "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid" ); } + + #[test] + fn test_array_literal_fix() { + let rules: Vec> = vec![Arc::new(FixArrayLiteral)]; + + assert_rewrite!( + &rules, + "SELECT '{a, abc}'::text[]", + "SELECT ARRAY['a', 'abc']::TEXT[]" + ); + + assert_rewrite!( + &rules, + "SELECT '{1, 2}'::int[]", + "SELECT ARRAY[1, 2]::INT[]" + ); + + assert_rewrite!( + &rules, + "SELECT '{t, f}'::bool[]", + "SELECT ARRAY[t, f]::BOOL[]" + ); + } } diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index 95edfb6..0778b0c 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -14,7 +14,7 @@ const DBEAVER_QUERIES: &[&str] = &[ "SHOW search_path", "SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'", "SELECT * FROM pg_catalog.pg_settings where name='standard_conforming_strings'", - "SELECT string_agg(word, ',' ) from pg_catalog.pg_get_keywords() where word <> ALL 
('{a,abs,absolute,action,ada,add,admin,after,all,allocate,alter,aIways,and,any,are,array,as,asc,asenstitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,before,begin,bernoulli,between,bigint,binary,blob,boolean,both,breaadth,by,c,call,called,cardinaliity,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,char_length,character,character_length,character_set_catalog,character_set_name,character_set_schema,characteristics,characters,check,checkeed,class_origin,clob,close,coalesce,coboI,code_units,collate,collation,collaition_catalog,collaition_name,collaition_schema,collect,colum,column_name,command_function,command_function_code,commit,committed,condiition,condiition_number,connect,connection_name,constraint,constraint_catalog,constraint_name,constraint_schema,constraints,constructors,contains,continue,convert,corr,correspondiing,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_collation,current_date,current_default_transfom_group,current_path,current_role,current_time,current_timestamp,current_transfom_group_for_type,current_user,cursor,cursor_name,cycle,data,date,datetime_interval_code,datetime_interval_precision,day,deallocate,dec,decimaI,declare,default,defaults,not,null,nullable,nullif,nulls,number,numeric,object,octeet_length,octets,of,old,on,only,open,option,options,or,order,ordering,ordinaliity,others,out,outer,output,over,overlaps,overlay,overriding,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_speciific_catalog,parameter_speciific_name,parameter_speciific_schema,partiaI,partitioon,pascal,path,percent_rank,percentile_cont,percentile_disc,placing,pli,position,power,preceding,precision,prepare,preseerv,primary,prior,privileges,procedure,public,range,rank,read,reads,real,recursivve,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_sy y,relative,release,repeatable,restart,result,retun,returned_cardinality,returned_length,returned_octeet_length,returned_sqlstate,returns,revoe,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,row_count,row_number,rows,savepoint,scale,schema,schema_name,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,self,sensitive,sequence,seriializeable,server_name,session,session_user,set,sets,similar,simple,size,smalIint,some,source,space,specifiic,speciific_name,speciifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,state,statement,static,stddev_pop,stddev_samp,structure,style,subclass_origin,submultiset,substring,sum,symmetric,system,system_user,table,table_name,tablesample,temporary,then,ties,time,timesamp,timezone_hour,timezone_minute,to,top_level_count,trailing,transaction,transaction_active,transactions_committed,transactions_rolled_back,transfor,transforms,translate,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,true,type,unbounde,undefined,uncommitted,under,union,unique,unknown,unnaamed,unnest,update,upper,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,value,values,var_pop,var_samp,varchar,varying,view,when,whenever,where,width_bucket,window,with,within,without,work,write,year,zone", + "SELECT string_agg(word, ',' ) from pg_catalog.pg_get_keywords() where word <> ALL 
('{a,abs,absolute,action,ada,add,admin,after,all,allocate,alter,aIways,and,any,are,array,as,asc,asenstitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,before,begin,bernoulli,between,bigint,binary,blob,boolean,both,breaadth,by,c,call,called,cardinaliity,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,char_length,character,character_length,character_set_catalog,character_set_name,character_set_schema,characteristics,characters,check,checkeed,class_origin,clob,close,coalesce,coboI,code_units,collate,collation,collaition_catalog,collaition_name,collaition_schema,collect,colum,column_name,command_function,command_function_code,commit,committed,condiition,condiition_number,connect,connection_name,constraint,constraint_catalog,constraint_name,constraint_schema,constraints,constructors,contains,continue,convert,corr,correspondiing,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_collation,current_date,current_default_transfom_group,current_path,current_role,current_time,current_timestamp,current_transfom_group_for_type,current_user,cursor,cursor_name,cycle,data,date,datetime_interval_code,datetime_interval_precision,day,deallocate,dec,decimaI,declare,default,defaults,not,null,nullable,nullif,nulls,number,numeric,object,octeet_length,octets,of,old,on,only,open,option,options,or,order,ordering,ordinaliity,others,out,outer,output,over,overlaps,overlay,overriding,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_speciific_catalog,parameter_speciific_name,parameter_speciific_schema,partiaI,partitioon,pascal,path,percent_rank,percentile_cont,percentile_disc,placing,pli,position,power,preceding,precision,prepare,preseerv,primary,prior,privileges,procedure,public,range,rank,read,reads,real,recursivve,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_sy y,relative,release,repeatable,restart,result,retun,returned_cardinality,returned_length,returned_octeet_length,returned_sqlstate,returns,revoe,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,row_count,row_number,rows,savepoint,scale,schema,schema_name,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,self,sensitive,sequence,seriializeable,server_name,session,session_user,set,sets,similar,simple,size,smalIint,some,source,space,specifiic,speciific_name,speciifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,state,statement,static,stddev_pop,stddev_samp,structure,style,subclass_origin,submultiset,substring,sum,symmetric,system,system_user,table,table_name,tablesample,temporary,then,ties,time,timesamp,timezone_hour,timezone_minute,to,top_level_count,trailing,transaction,transaction_active,transactions_committed,transactions_rolled_back,transfor,transforms,translate,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,true,type,unbounde,undefined,uncommitted,under,union,unique,unknown,unnaamed,unnest,update,upper,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,value,values,var_pop,var_samp,varchar,varying,view,when,whenever,where,width_bucket,window,with,within,without,work,write,year,zone}'::text[])", "SELECT version()", "SELECT * FROM pg_catalog.pg_enum WHERE 1<>1 LIMIT 1", "SELECT reltype FROM pg_catalog.pg_class WHERE 1<>1 LIMIT 1", From b0c63448bd9814e235be886da4c7e64e65034a00 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: 
Wed, 3 Sep 2025 19:28:56 +0800 Subject: [PATCH 10/13] feat: implement more about pg_get_keywords --- datafusion-postgres/src/handlers.rs | 8 +- datafusion-postgres/src/pg_catalog.rs | 45 ++++- datafusion-postgres/src/sql.rs | 50 +++++ flake.nix | 3 +- .../pg_get_keywords.feather | Bin 0 -> 13602 bytes pg_to_arrow.py | 181 ++++++++++++++++++ 6 files changed, 275 insertions(+), 12 deletions(-) create mode 100644 pg_catalog_arrow_exports/pg_get_keywords.feather create mode 100644 pg_to_arrow.py diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index a2984e4..6bb121c 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ - parse, rewrite, AliasDuplicatedProjectionRewrite, PrependUnqualifiedTableName, - RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, RewriteArrayAnyOperation, - SqlStatementRewriteRule, + parse, rewrite, AliasDuplicatedProjectionRewrite, FixArrayLiteral, PrependUnqualifiedTableName, + RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, + RewriteArrayAnyOperation, SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; @@ -84,6 +84,8 @@ impl DfSessionService { Arc::new(RemoveUnsupportedTypes::new()), Arc::new(RewriteArrayAnyOperation), Arc::new(PrependUnqualifiedTableName::new()), + Arc::new(FixArrayLiteral), + Arc::new(RemoveTableFunctionQualifier), ]; let parser = Arc::new(Parser { session_context: session_context.clone(), diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 35573ad..96a0d18 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -10,13 +10,13 @@ use datafusion::arrow::array::{ use datafusion::arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion::arrow::ipc::reader::FileReader; use datafusion::catalog::streaming::StreamingTable; -use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider}; +use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl}; use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::datasource::{TableProvider, ViewTable}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility}; use datafusion::physical_plan::streaming::PartitionStream; -use datafusion::prelude::{create_udf, SessionContext}; +use datafusion::prelude::{create_udf, Expr, SessionContext}; use postgres_types::Oid; use tokio::sync::RwLock; @@ -199,7 +199,7 @@ pub struct PgCatalogSchemaProvider { catalog_list: Arc, oid_counter: Arc, oid_cache: Arc>>, - static_tables: PgCatalogStaticTables, + static_tables: Arc, } #[async_trait] @@ -363,12 +363,15 @@ impl SchemaProvider for PgCatalogSchemaProvider { } impl PgCatalogSchemaProvider { - pub fn try_new(catalog_list: Arc) -> Result { + pub fn try_new( + catalog_list: Arc, + static_tables: Arc, + ) -> Result { Ok(Self { catalog_list, oid_counter: Arc::new(AtomicU32::new(16384)), oid_cache: Arc::new(RwLock::new(HashMap::new())), - static_tables: PgCatalogStaticTables::try_new()?, + static_tables, }) } } @@ -406,10 +409,17 @@ impl ArrowTable { } } +impl TableFunctionImpl for ArrowTable { + fn call(&self, _args: &[Expr]) -> Result> { + let table = self.clone().try_into_memtable()?; + Ok(Arc::new(table)) + } +} + /// pg_catalog table as datafusion table 
provider /// /// This implementation only contains static tables -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PgCatalogStaticTables { pub pg_aggregate: Arc, pub pg_am: Arc, @@ -468,6 +478,8 @@ pub struct PgCatalogStaticTables { pub pg_tablespace: Arc, pub pg_trigger: Arc, pub pg_user_mapping: Arc, + + pub pg_get_keywords: Arc, } impl PgCatalogStaticTables { @@ -654,6 +666,10 @@ impl PgCatalogStaticTables { pg_user_mapping: Self::create_arrow_table( include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(), )?, + + pg_get_keywords: Self::create_arrow_table_function( + include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(), + )?, }) } @@ -663,6 +679,11 @@ impl PgCatalogStaticTables { let mem_table = table.try_into_memtable()?; Ok(Arc::new(mem_table)) } + + fn create_arrow_table_function(data_bytes: Vec) -> Result> { + let table = ArrowTable::from_ipc_data(data_bytes)?; + Ok(Arc::new(table)) + } } pub fn create_current_schemas_udf() -> ScalarUDF { @@ -901,8 +922,11 @@ pub fn setup_pg_catalog( session_context: &SessionContext, catalog_name: &str, ) -> Result<(), Box> { - let pg_catalog = - PgCatalogSchemaProvider::try_new(session_context.state().catalog_list().clone())?; + let static_tables = Arc::new(PgCatalogStaticTables::try_new()?); + let pg_catalog = PgCatalogSchemaProvider::try_new( + session_context.state().catalog_list().clone(), + static_tables.clone(), + )?; session_context .catalog(catalog_name) .ok_or_else(|| { @@ -920,6 +944,7 @@ pub fn setup_pg_catalog( session_context.register_udf(create_pg_table_is_visible()); session_context.register_udf(create_format_type_udf()); session_context.register_udf(create_session_user_udf()); + session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone()); Ok(()) } @@ -1173,5 +1198,9 @@ mod test { include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(), ) .expect("Failed to load ipc data"); + let _ = ArrowTable::from_ipc_data( + include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(), + ) + .expect("Failed to load ipc data"); } } diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 1b43ea0..afd135d 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -545,6 +545,44 @@ impl SqlStatementRewriteRule for FixArrayLiteral { } } +/// Remove qualifier from table function +/// +/// The query engine doesn't support qualified table function name +#[derive(Debug)] +pub struct RemoveTableFunctionQualifier; + +struct RemoveTableFunctionQualifierVisitor; + +impl VisitorMut for RemoveTableFunctionQualifierVisitor { + type Break = (); + + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + if let TableFactor::Table { name, args, .. 
} = table_factor { + if args.is_some() { + // multiple idents in name, which means it's a qualified table name + if name.0.len() > 1 { + if let Some(last_ident) = name.0.pop() { + *name = ObjectName(vec![last_ident]); + } + } + } + } + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RemoveTableFunctionQualifier { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RemoveTableFunctionQualifierVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -715,4 +753,16 @@ mod tests { "SELECT ARRAY[t, f]::BOOL[]" ); } + + #[test] + fn test_remove_qualifier_from_table_function() { + let rules: Vec> = + vec![Arc::new(RemoveTableFunctionQualifier)]; + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_get_keywords()", + "SELECT * FROM pg_get_keywords()" + ); + } } diff --git a/flake.nix b/flake.nix index 44ac8bb..bda9fbd 100644 --- a/flake.nix +++ b/flake.nix @@ -15,7 +15,8 @@ let pkgs = nixpkgs.legacyPackages.${system}; pythonEnv = pkgs.python3.withPackages (ps: with ps; [ - psycopg + psycopg2-binary + pyarrow ]); buildInputs = with pkgs; [ llvmPackages.libclang diff --git a/pg_catalog_arrow_exports/pg_get_keywords.feather b/pg_catalog_arrow_exports/pg_get_keywords.feather new file mode 100644 index 0000000000000000000000000000000000000000..17099bd1de89e89d336e16ec64927f41b926cbea GIT binary patch literal 13602 zcmeI(d3@B=-uUs8ndx)^nih(zLR)r3XxU}6r6?#0rM7~IKsu9YLuZm9$#g-5qOvHc zAPOo92r7sopeQI_HaA4PD!2pUhKg*(1+Q1M&*u!_&)4(&-QPdYzt6PqW|GMyC+D1; z^UX<`UsN=DY_1UXygDmFqzU;-5epTBf@VO9su>h0N8$OI< zTYM=&ZmaeAWlZAVE7fd0s5=2Ue+}#1@wauGvrK+pz7i48-1w?4>pXLAJ^z2-W-{x^ zZHBc{J{6D7q zWd`y3(SqjU#{^#oqy$gzWB1QZ9X~gHfC$( zX63FLpb1)_EwUj~-3@)vAA^vO;TVMyjKhtXj8f>ZP>tI#ANSyXJd9P?fGyaDSFs20 z;Uj#CFYpbH;RJrf8JvZeENTE1P0$i;(GJ}(5P7&7qfv|-5x^8Us6j2};9fj{Wq1lN zUk*<6($N&HkcoEagzgxOt5AXfEYxB??!yu+#VS0B4R{_eV<&dw zJ$#0*a1^KTFLkkBZ5lQU?yhaE-b(USc;W+ z8k?{cJFy%4Z~$N92b{t$_%~!bkc38Pj8oEyBreG>+Fdefo zAB*r1mSYuGV?8$FMQq1zyoZnQ2@d0XoW`#>3ojKu1&z@PS?GvMkb{B9#|Vr>5pF~P zWiSy%C1&7u%)vq|!&BIVUD%6{aRjIE8~y<8qVORZ>9`0jk%i94MPCfS6S|`(`e87xz$lExM1)`>hH1D1_h1Pg!x}t~E!c)v@h0BJ zejLPM9L3N09Ul7R6r`gWvd|el&>I6$fU8i18!!n$lp~5t+=^M4i}_f9hp-$^;2AuJ zE!c)PuooZT5RT#$&cH+Wn1B?dq6wOz4cenKa?l@xFceo{BuX$I0YqS<8Z$5p^RN($ zu@XDxvNBsQ(^w-M*20dZg`0|<;VP}h(M&B|cWqs)v{i8|rbL?XIt8-%NxYeka8 z3Rer1Yc7jQ7l_)JTW;CLG_Aj=>sP7;DylR)SgMCCTQAki46{^st90EgHF)oql1jU# z)No5JD@-CvEw{YX*0ms$ngCnX1Ekw0a`a$esc07@wL+D_XPP@#>LwB2(N%aXY!j`&h&; z@|=JbNl?6rLQzG!kt*IdT$jCDk*G}$_vQG_K+Lvv(+$@IEHh}NDyy_G??Ogd%pNZ` znxU`}a9LpLJCrC>$F;TP@@yQ#woTU+yG++GW9;3m)QPeeMA?wz25fz!c;2>1sZ$w< zmFhPOQjsyKo#I(2GA9MQ3vWZw$F{;$k9JD?h}K8+ag7M^Z9`G)xLPE7F~=LDv`C+0 z*@2osSPxQ0K2cjrc?PwbpyT-|)uPwkivPR)m5JS}|AT*;Z87JuAYFu1BJ7jcyJSU+9!a&@juG+BMPH zqAtS>>L!&aK+?@vgk({n*fd|Qq%8F6fFAXWq2XA(jtrq^iiN|WnAu(NC2|hIPAKt7J(zrv>k=aIoBp z*#_C*8g7jdGQviMZg5Hl4b4p!nK3t@oc0=#DAkZNI1;6?FigW7tkCJ#vLQ|wTVJoF z88<76U!+$jh()?%nB}@{Q2uqPyyttIcQu9?CbJA`vTc}@m0^;)G-slcCtC)Z%L)#O zrX^^M7dft8)=lKmqcbKa>_SUknCd?I~P zZh}aa9ab)FRU_f~UQrr|GOKio9&j57UwbiTYD~9lnu}73o2uG?{IZQhijpdnPKuHu z>eAU!Os~-OXoX(GuM@vG4r^vvOe^C`*-j9dj=Ra1Z*x$GL|t+jZB0egwyY3u95<|o zP^pLSR+J*|4wC5VCZ!wsm5Mzm;@ ztc_^Y8NSd&WVqW!K8ve2iJ{b1V#~&pmir57ILoaS{!(frEp!KQT5xe5p;ai4Skz-J z;uBepX(x2;Ma-_fE@Ly;5KY7#}J zZA5*A7O4(WsqLutwvrZ361ikTx=42jLyvnDT69-LqcqT@nCj=1BM(u{Ess*lR*>^G zAlo5IleeZ5wW!uLQ7gq8wM?4ifId~0Oqy|v&tW!7Q(y4%Y)TWK21z}S* ze=3czUZ%66MLV)Y@ui5SbQ$7eQOl_qiN;F9Et6)_8BC2?E{8eU;RG$yKTj;APYT-; zlk(^T?T}w|PRO%4yA%AyG*G%j-6iqOMZQyR>xwVa*5`RecOq~Pe`jixKh31p)5`~3 
z=%y3Si;c8ibfOfjM$EjD-jo>%oGx9NC_A~Hj;;Ub%_%E)nW5`tO6J9)_Gn)d z*=;M1ZB_B>*fDxDtx_MIU+ ze5Vr~lj1J7T*r#&)B{>)XKL8#s&KmfD>*Z4jrC5|Y?otnIS0w;zKMka$8m!_5;n-T z-*GJh6pkfZ4>CYQvduETm>G3qoa?rcJgOtLcTN7m~*-q@+ z#1~=&E-^Z&V9Y$8%y>(}$H^UHX1Amt6JjQ{i9a-9F|&e`kQp-}J7x|^h{kBZ!?Eb> zWbwX|7JJn@xl#+nVv*_obWWFEbka#8&AQ*87E2X*G2N-W##5K9(Wu%64Q_xb?Q}w5zDHX~uNBijF|9j)eK`qlhmXh4JA`&MhHF1m$>_9A<2ai_i?sA;-Yvh@cz= zB;B0(ToXZ#i)A7U*=UFM=zxysgwE)KuDArGd_41yJLIr5N?p(wyGT!AYw93yZQuEt1QgKIGgqfv-4D1sc$EQTD%yB=dP z4mV&t6jUIL2uxUrVk&Gna1lc#s!)v@OvBB% z1-D{4W}p@`aT{*O9himLn1i{v6L(=A=HqVMgL|<63vnM7;eI@T#aMy|@em%yQapl3 zu?)+x0*_%OR^f3xfhVyVYw#4-;%PjCb$Ax*u>l+L9G=G}ynq+68C&oYUdC3uf`4Ee zwqpl&Vi#V;Yj_=R;7#ntTiAoQu@~>)pLiGV;eG7Getdur@ew}8C-@Yf;Q$We5I)CY ze1R|V6~4wdID&8S9lpm={D5OPjuSYEQ>epf{D`0MGk(Dt{EC0!H~bsF<1GHbpQwjW z_=N{@EZBzx_>l;C`yd%9kmJM+(Fm!KW5#JnM`OtG;|w&#MQDcRXaPB{+zN6?{$jL& z9B0l%7P2A7oZF)VI-(OgqYJv?638*>OVJ%Ykb_+GL@)FP?WgFAe&~!L=9#IUZh!F(|@ykmKVexE^CM4mV&tVOMm&e-u?a8WMQp|vyo8sr6|dkQ*oN)cft}ce zSMeHN#~XMPyYUwG;BD;1JNPHw#d~-k`>-D$;6r?bkMRjU#b-ExgE)lGaTs6VOMHc| z@ePjPTYQJ_aTGt`7>?rvPU5fEg_-l?Rk?hYQKE>VqLPx5B4$!#U%a50xq^b?f`ZoZ z?!J}CE-1KI^eQMSZYKL$Plm{0^+uwwsHmv8sGD3#$r0YJ!qZWBTM8vZ#(kbz#>g}a^?4W7L^qD7f-T~H^n8z1tcO%cHCoFv$*70;k{WX zY`(u-xp6^3d|&GXE0+`%w3R)B(wlZjcE!p?!sn6kf~Si-5c$5HNJW6`dtXn<_a(hFwEmc`dGehtuCcpA$}<>@Zk7ZqiQrX{IL zVL<^oBQtZz`5Zl;kAs-audn~D-p2*{jQSbe&8WGjCqN@AX!F{i*H0FC%{C zP5t(EVuPF>uHU}BuD;*-t^LQii(-k#OKq5TMo-9nUWr1CM=8RnL@nlG0UpE(tic9s z!7l8D?0aPYBKr>6FUU6EPYGO#e#pm2RACil-9CeC^0q5_V+cl|2on%O6xFE3UAP}R za2#jwC%Alw-$!fBv?JubGI^hDCho){Jc`xW2zejubsWbT2p^@4475f^$a`D^FbtzG z4wEql3$YYWU>#n-KkyBX<7dcwRDLQ`8d{<~x}!gc`o#h~jFor>@?O(6?8ZJE#J8x! zzrj_xXb5?aDHC1M6Y}2DFpR=Dlp+il(=iwFKGGw29M574c405%{iDwz?;HIHdC$n7 z$hI*ESE3N(A@A?pg9ot!Yq1I2@B@CvS@@IatI-_UxC&!XjoUCEi?IUoKFo9Y9N(gG zGG%~{=#KsvifeE^CSwj3Vkw@$dTho{yp0cW81g>JPxu3gRD2aJkcTiTQH#59KX%|a zenw^k&JXm&4XA~@_plL%klc_y9~$n)OE`k`M)dVC@Bo(MDLjvV;7z=b1Na)p@C&k1 zsrx8G4Q65<9>8*}#(KPjSMd%$#+NvTUr>)^m6!;Z<0@Q-n^1FSg<Z#uCVnj#+~QH7b9hX=46tFaMV@jBkYhd7L*_zAzm zOHF|>CSe{P zz;Zl=zg`z+&c_S#Iga%)Qw}mR3K~oigVk}O36Y1IBraB;6%Bl&gxXOw@V61Z9C0Qs zvvWW*GfVYMZdJr@B4%24rX1qRWK3G`k}Dg@2Fh@8Ib+IKi*&V{P}N688}&M&`nL-? z>zg6`z2$0M*fc|n>NTQ)x{QzW7t2$g7BvZHN?sjQ^&)LbrEX9{(ylEIqwPx{+>;+Z$bBeQ)d8*lDWeM5xORj8*tiT<19 zH!@@{UnUbS<4?I_`*@|Hyh5Djs>hCOzI9WX?E~UaLBZN9b>7Wc2s-|T4rV9UI`JLO<4@4hd zoy-#dNm(Gu6rtX%bnsJ9GP5XN_M%h?^@VtTtG+4m?6ZX;kA$r7cZ*l+zLdVK7Do=vc58^v07wRg{WxgBY`&=y-$(_gZsf)6cy6oxeze{k!)=%>< zR#yC<)A+Bm{XfR>bd~Y(%U(=}p#&jRVJ;rRT5QE$9KuP+@s~8>?kseJjLk>j226px zUY`$ny)3VH<@Kt(-jvsiKf+7Aoq-PMi{Tgt1Jf`Mk6<0P;az+IhKGfRwyFu*qYv^h z5)&YwWt)zBu>vwq--Z472J+c7KQVPPbjD@48aIOH2E?tn2lBbH4cLi&_!=@EPauZA z2%XRq1tvR$?8t;az-&V>pXO#LDuypDvKk`doqQQHE;ViHGqtXAZR-3lGi4NeKSIXe z8R&q%7>;o;Fb%V?7^|@bui^v9c>D|!iJ_aL9eU#mT#qu?xE1$cIW}M?_Tg*%j0EE3 zi_i)EaTUg+0=M99EW>*2fQ+>d;spMLN^G2oTnxo^2qFdr; zk%N2`ApjS%uo$bc1#jUq9K%^868APo7Yx8iOh5$FaW7Ut#>=~~AK%~?_=$C8tlSwg zP9BC5giwXKcnE8;6?<_ACsB_yV%#kBL;;GSqY`rpXOj2N~-59DDCO5xxRJb)*$8N09_-{7y;g_-kld;)!Iz2|&v zF8}W%GyZH#y7*(W%}soR06#OIha=UF${%W)av(9;lgOM;IiQ{t2h{JBU)2#A+M^_l7EHi9<$wRbN9NPH|L-Xzv-0Er G8}(m&QG28S literal 0 HcmV?d00001 diff --git a/pg_to_arrow.py b/pg_to_arrow.py new file mode 100644 index 0000000..a364a34 --- /dev/null +++ b/pg_to_arrow.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +Export PostgreSQL query results to Arrow IPC Feather format. 
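+(Added in this patch series together with pg_catalog_arrow_exports/pg_get_keywords.feather,
+which appears to have been generated with this script.)
+
+Example invocation (hypothetical values):
+    python pg_to_arrow.py --host localhost --user postgres \
+        --query "SELECT * FROM pg_catalog.pg_get_keywords()" \
+        --output pg_catalog_arrow_exports/pg_get_keywords.feather
+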
+Minimal dependencies: psycopg2, pyarrow +""" + +import argparse +import psycopg2 +import pyarrow as pa +import pyarrow.feather as feather +from psycopg2.extras import RealDictCursor +from typing import Dict, Any, List, Optional +import sys + +def map_postgresql_to_arrow_type(type_oid: int) -> pa.DataType: + """Map PostgreSQL data types to Arrow data types.""" + # Map OIDs to Arrow types + type_mapping = { + # Integer types (OIDs from PostgreSQL documentation) + 20: pa.int64(), # int8 (bigint) + 21: pa.int16(), # int2 (smallint) + 23: pa.int32(), # int4 (integer) + 26: pa.int32(), # oid + + # Floating point types + 700: pa.float32(), # float4 (real) + 701: pa.float64(), # float8 (double precision) + 1700: pa.float64(), # numeric (decimal) + + # Boolean + 16: pa.bool_(), # bool + + # String types + 25: pa.string(), # text + 1043: pa.string(), # varchar + 18: pa.string(), # char + 19: pa.string(), # name + + # Date/time types + 1082: pa.date32(), # date + 1114: pa.timestamp('us'), # timestamp without time zone + 1184: pa.timestamp('us', tz='UTC'), # timestamp with time zone + 1083: pa.time64('us'), # time without time zone + 1266: pa.time64('us'), # time with time zone + + # Binary data + 17: pa.binary(), # bytea + + # JSON types + 114: pa.string(), # json + 3802: pa.string(), # jsonb + + # UUID + 2950: pa.string(), # uuid (Arrow doesn't have native UUID type) + + # Network types + 869: pa.string(), # inet + 650: pa.string(), # cidr + 829: pa.string(), # macaddr + } + + return type_mapping.get(type_oid, pa.string()) # Fallback to string + +def export_query_to_feather( + connection_string: str, + query: str, + output_file: str, + batch_size: int = 10000 +) -> None: + """Execute PostgreSQL query and export results to Arrow Feather format.""" + + try: + # Connect to PostgreSQL + conn = psycopg2.connect(connection_string) + cursor = conn.cursor(cursor_factory=RealDictCursor) + + # Execute query + cursor.execute(query) + + # Get column information + columns = [] + arrow_types = [] + column_names = [] + + for desc in cursor.description: + col_name = desc.name + col_oid = desc.type_code + + arrow_type = map_postgresql_to_arrow_type(col_oid) + + columns.append(col_name) + arrow_types.append(arrow_type) + column_names.append(col_name) + + # Process data in batches + all_data = {col: [] for col in columns} + rows_processed = 0 + + while True: + batch = cursor.fetchmany(batch_size) + if not batch: + break + + for row in batch: + for col in columns: + all_data[col].append(row[col]) + + rows_processed += len(batch) + print(f"Processed {rows_processed} rows...", end='\r') + + print(f"\nTotal rows processed: {rows_processed}") + + if rows_processed > 0: + # Convert to Arrow Table + arrays = [] + for col, arrow_type in zip(columns, arrow_types): + try: + array = pa.array(all_data[col], type=arrow_type) + except (pa.ArrowInvalid, pa.ArrowTypeError) as e: + print(f"Warning: Could not convert column '{col}' to {arrow_type}: {e}") + print("Falling back to string type") + array = pa.array([str(x) if x is not None else None for x in all_data[col]], type=pa.string()) + arrays.append(array) + + # Create table and write to feather + table = pa.Table.from_arrays(arrays, names=column_names) + feather.write_feather(table, output_file) + + print(f"Successfully exported {rows_processed} rows to {output_file}") + print(f"Schema: {table.schema}") + else: + print("No data found for the query.") + + except psycopg2.Error as e: + print(f"PostgreSQL error: {e}") + sys.exit(1) + except Exception as e: + print(f"Error: {e}") + 
import traceback + traceback.print_exc() + sys.exit(1) + finally: + if 'cursor' in locals(): + cursor.close() + if 'conn' in locals(): + conn.close() + +def main(): + parser = argparse.ArgumentParser(description='Export PostgreSQL query to Arrow Feather format') + + # Connection options + parser.add_argument('--host', default='localhost', help='PostgreSQL host') + parser.add_argument('--port', type=int, default=5432, help='PostgreSQL port') + parser.add_argument('--database', default='postgres', help='Database name') + parser.add_argument('--user', default='postgres', help='Database user') + parser.add_argument('--password', default='', help='Database password') + + # Alternative: connection string + parser.add_argument('--connection-string', help='PostgreSQL connection string (overrides individual connection params)') + + parser.add_argument('--query', required=True, help='SQL query to execute') + parser.add_argument('--output', required=True, help='Output feather file path') + parser.add_argument('--batch-size', type=int, default=10000, help='Batch size for processing') + + args = parser.parse_args() + + # Build connection string + if args.connection_string: + connection_string = args.connection_string + else: + connection_string = f"host={args.host} port={args.port} dbname={args.database} user={args.user} password={args.password}" + + export_query_to_feather( + connection_string=connection_string, + query=args.query, + output_file=args.output, + batch_size=args.batch_size + ) + +if __name__ == "__main__": + main() From 0543e7a4613b95bf5e1e08a0f9f505f160ada7c4 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 3 Sep 2025 22:49:15 +0800 Subject: [PATCH 11/13] feat: add support for ALLop --- datafusion-postgres/src/handlers.rs | 4 +-- datafusion-postgres/src/sql.rs | 49 +++++++++++++++++++---------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 6bb121c..7d2ccb3 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -5,7 +5,7 @@ use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ parse, rewrite, AliasDuplicatedProjectionRewrite, FixArrayLiteral, PrependUnqualifiedTableName, RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, - RewriteArrayAnyOperation, SqlStatementRewriteRule, + RewriteArrayAnyAllOperation, SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; @@ -82,7 +82,7 @@ impl DfSessionService { Arc::new(AliasDuplicatedProjectionRewrite), Arc::new(ResolveUnqualifiedIdentifer), Arc::new(RemoveUnsupportedTypes::new()), - Arc::new(RewriteArrayAnyOperation), + Arc::new(RewriteArrayAnyAllOperation), Arc::new(PrependUnqualifiedTableName::new()), Arc::new(FixArrayLiteral), Arc::new(RemoveTableFunctionQualifier), diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index afd135d..65c8021 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -343,11 +343,11 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes { /// Rewrite Postgres's ANY operator to array_contains #[derive(Debug)] -pub struct RewriteArrayAnyOperation; +pub struct RewriteArrayAnyAllOperation; -struct RewriteArrayAnyOperationVisitor; +struct RewriteArrayAnyAllOperationVisitor; -impl RewriteArrayAnyOperationVisitor { +impl RewriteArrayAnyAllOperationVisitor { fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr { 
Expr::Function(Function { name: ObjectName::from(vec![Ident::new("array_contains")]), @@ -369,21 +369,33 @@ impl RewriteArrayAnyOperationVisitor { } } -impl VisitorMut for RewriteArrayAnyOperationVisitor { +impl VisitorMut for RewriteArrayAnyAllOperationVisitor { type Break = (); fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> { - if let Expr::AnyOp { - left, - compare_op, - right, - .. - } = expr - { - match compare_op { + match expr { + Expr::AnyOp { + left, + compare_op, + right, + .. + } => match compare_op { BinaryOperator::Eq => { *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref()); } + BinaryOperator::NotEq => { + // TODO: rewrite left <> ANY(array) (true when left differs from at least one element) + } + _ => {} + }, + Expr::AllOp { + left, + compare_op, + right, + } => match compare_op { + BinaryOperator::Eq => { + // TODO: rewrite left = ALL(array) (true only when left equals every element) + } BinaryOperator::NotEq => { *expr = Expr::UnaryOp { op: UnaryOperator::Not, @@ -391,15 +403,17 @@ impl VisitorMut for RewriteArrayAnyOperationVisitor { } } _ => {} - } + }, + _ => {} } + ControlFlow::Continue(()) } } -impl SqlStatementRewriteRule for RewriteArrayAnyOperation { +impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation { fn rewrite(&self, mut s: Statement) -> Statement { - let mut visitor = RewriteArrayAnyOperationVisitor; + let mut visitor = RewriteArrayAnyAllOperationVisitor; let _ = s.visit(&mut visitor); @@ -686,7 +700,8 @@ mod tests { #[test] fn test_any_to_array_contains() { - let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RewriteArrayAnyOperation)]; + let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = + vec![Arc::new(RewriteArrayAnyAllOperation)]; assert_rewrite!( &rules, @@ -696,7 +711,7 @@ assert_rewrite!( &rules, - "SELECT a != ANY(current_schemas(true))", + "SELECT a <> ALL(current_schemas(true))", "SELECT NOT array_contains(current_schemas(true), a)" ); From fe43ea820b83c4391fa5be904fd0889688856696 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 4 Sep 2025 16:12:04 +0800 Subject: [PATCH 12/13] feat: add remaining udfs for dbeaver startup queries --- datafusion-postgres/src/pg_catalog.rs | 55 ++++++++++++++++++- .../src/pg_catalog/pg_class.rs | 4 ++ datafusion-postgres/tests/dbeaver.rs | 4 ++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 96a0d18..6701d58 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -890,7 +890,7 @@ pub fn create_format_type_udf() -> ScalarUDF { create_udf( "format_type", - vec![DataType::Int32, DataType::Int32], + vec![DataType::Int64, DataType::Int32], DataType::Utf8, Volatility::Stable, Arc::new(func), ) @@ -917,6 +917,57 @@ pub fn create_session_user_udf() -> ScalarUDF { ) } +pub fn create_pg_get_expr_udf() -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let expr = &args[0]; + let _oid = &args[1]; + + // For now, return an empty string for every input row; expression decompilation is not implemented yet + let mut builder = StringBuilder::new(); + for _ in 0..expr.len() { + builder.append_value(""); + } + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "pg_catalog.pg_get_expr", + vec![DataType::Utf8, DataType::Int32], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + +pub fn create_pg_get_partkeydef_udf() -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let oid = &args[0]; + + // For now, always 
return true (full access for current user) + let mut builder = StringBuilder::new(); + for _ in 0..oid.len() { + builder.append_value(""); + } + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "pg_catalog.pg_get_partkeydef", + vec![DataType::Utf8], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + /// Install pg_catalog and postgres UDFs to current `SessionContext` pub fn setup_pg_catalog( session_context: &SessionContext, @@ -945,6 +996,8 @@ pub fn setup_pg_catalog( session_context.register_udf(create_format_type_udf()); session_context.register_udf(create_session_user_udf()); session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone()); + session_context.register_udf(create_pg_get_expr_udf()); + session_context.register_udf(create_pg_get_partkeydef_udf()); Ok(()) } diff --git a/datafusion-postgres/src/pg_catalog/pg_class.rs b/datafusion-postgres/src/pg_catalog/pg_class.rs index 72f2211..2767c6c 100644 --- a/datafusion-postgres/src/pg_catalog/pg_class.rs +++ b/datafusion-postgres/src/pg_catalog/pg_class.rs @@ -63,6 +63,7 @@ impl PgClassTable { Field::new("relrewrite", DataType::Int32, true), // OID of a rule that rewrites this relation Field::new("relfrozenxid", DataType::Int32, false), // All transaction IDs before this have been replaced with a permanent ("frozen") transaction ID Field::new("relminmxid", DataType::Int32, false), // All Multixact IDs before this have been replaced with a transaction ID + Field::new("relpartbound", DataType::Utf8, true), ])); Self { @@ -106,6 +107,7 @@ impl PgClassTable { let mut relrewrites = Vec::new(); let mut relfrozenxids = Vec::new(); let mut relminmxids = Vec::new(); + let mut relpartbound = Vec::new(); let mut oid_cache = this.oid_cache.write().await; // Every time when call pg_catalog we generate a new cache and drop the @@ -190,6 +192,7 @@ impl PgClassTable { relrewrites.push(None); relfrozenxids.push(0); relminmxids.push(0); + relpartbound.push("".to_string()); } } } @@ -231,6 +234,7 @@ impl PgClassTable { Arc::new(Int32Array::from_iter(relrewrites.into_iter())), Arc::new(Int32Array::from(relfrozenxids)), Arc::new(Int32Array::from(relminmxids)), + Arc::new(StringArray::from(relpartbound)), ]; // Create a record batch diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index 0778b0c..e132b91 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -19,6 +19,10 @@ const DBEAVER_QUERIES: &[&str] = &[ "SELECT * FROM pg_catalog.pg_enum WHERE 1<>1 LIMIT 1", "SELECT reltype FROM pg_catalog.pg_class WHERE 1<>1 LIMIT 1", "SELECT t.oid,t.*,c.relkind,format_type(nullif(t.typbasetype, 0), t.typtypmod) as base_type_name, d.description FROM pg_catalog.pg_type t LEFT OUTER JOIN pg_catalog.pg_type et ON et.oid=t.typelem LEFT OUTER JOIN pg_catalog.pg_class c ON c.oid=t.typrelid LEFT OUTER JOIN pg_catalog.pg_description d ON t.oid=d.objoid WHERE t.typname IS NOT NULL AND (c.relkind IS NULL OR c.relkind = 'c') AND (et.typcategory IS NULL OR et.typcategory <> 'C')", + "SELECT c.oid,c.*,d.description,pg_catalog.pg_get_expr(c.relpartbound, c.oid) as partition_expr, pg_catalog.pg_get_partkeydef(c.oid) as partition_key + FROM pg_catalog.pg_class c + LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=c.oid AND d.objsubid=0 AND d.classoid='pg_class'::regclass + WHERE c.relnamespace=11 AND c.relkind not in ('i','I','c')" ]; #[tokio::test] From d84da74016f7585c79bdf30da4463e765363500d 
Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 4 Sep 2025 16:41:09 +0800 Subject: [PATCH 13/13] chore: update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 6a82025..762d786 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ project. - Permission control - Built-in `pg_catalog` tables - Built-in postgres functions for common meta queries + - [x] DBeaver compatibility - `datafusion-postgres-cli`: A cli tool starts a postgres compatible server for datafusion supported file formats, just like python's `SimpleHTTPServer`. - `arrow-pg`: A data type mapping, encoding/decoding library for arrow and