Skip to content

Commit 6eda2db

Browse files
authored
feat: implement first \d table query (#169)
* feat: implement first \d table query
* feat: update rules WIP
* feat: implement more psql \d queries
* feat: add support psql \d queries
* feat: make format_type work for both int64/int32
* refactor: split and rename sql module
* fix: lint
* fix: integration test on version()
* chore: update flake for integration test dependencies
1 parent 69dfdaf commit 6eda2db

File tree

13 files changed: +1822 / -1110 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,4 +3,4 @@
33
.envrc
44
.vscode
55
.aider*
6-
/test_env
6+
/tests-integration/test_env

datafusion-postgres/src/handlers.rs

Lines changed: 14 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,7 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
5-
use crate::sql::{
6-
parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter,
7-
CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, PrependUnqualifiedPgTableName,
8-
RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
9-
RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
10-
};
5+
use crate::sql::PostgresCompatibilityParser;
116
use async_trait::async_trait;
127
use datafusion::arrow::datatypes::{DataType, Field, Schema};
138
use datafusion::common::ToDFSchema;
@@ -91,37 +86,22 @@ pub struct DfSessionService {
9186
parser: Arc<Parser>,
9287
timezone: Arc<Mutex<String>>,
9388
auth_manager: Arc<AuthManager>,
94-
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
9589
}
9690

9791
impl DfSessionService {
9892
pub fn new(
9993
session_context: Arc<SessionContext>,
10094
auth_manager: Arc<AuthManager>,
10195
) -> DfSessionService {
102-
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
103-
// make sure blacklist based rewriter it on the top to prevent sql
104-
// being rewritten from other rewriters
105-
Arc::new(BlacklistSqlRewriter::new()),
106-
Arc::new(AliasDuplicatedProjectionRewrite),
107-
Arc::new(ResolveUnqualifiedIdentifer),
108-
Arc::new(RemoveUnsupportedTypes::new()),
109-
Arc::new(RewriteArrayAnyAllOperation),
110-
Arc::new(PrependUnqualifiedPgTableName),
111-
Arc::new(FixArrayLiteral),
112-
Arc::new(RemoveTableFunctionQualifier),
113-
Arc::new(CurrentUserVariableToSessionUserFunctionCall),
114-
];
11596
let parser = Arc::new(Parser {
11697
session_context: session_context.clone(),
117-
sql_rewrite_rules: sql_rewrite_rules.clone(),
98+
sql_parser: PostgresCompatibilityParser::new(),
11899
});
119100
DfSessionService {
120101
session_context,
121102
parser,
122103
timezone: Arc::new(Mutex::new("UTC".to_string())),
123104
auth_manager,
124-
sql_rewrite_rules,
125105
}
126106
}
127107

@@ -457,13 +437,14 @@ impl SimpleQueryHandler for DfSessionService {
457437
return Ok(vec![resp]);
458438
}
459439

460-
let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
440+
let mut statements = self
441+
.parser
442+
.sql_parser
443+
.parse(query)
444+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
461445

462446
// TODO: deal with multiple statements
463-
let mut statement = statements.remove(0);
464-
465-
// Attempt to rewrite
466-
statement = rewrite(statement, &self.sql_rewrite_rules);
447+
let statement = statements.remove(0);
467448

468449
// TODO: improve statement check by using statement directly
469450
let query = statement.to_string();
@@ -717,7 +698,7 @@ impl ExtendedQueryHandler for DfSessionService {
717698

718699
pub struct Parser {
719700
session_context: Arc<SessionContext>,
720-
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
701+
sql_parser: PostgresCompatibilityParser,
721702
}
722703

723704
impl Parser {
@@ -790,11 +771,11 @@ impl QueryParser for Parser {
790771
return Ok((sql.to_string(), plan));
791772
}
792773

793-
let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
794-
let mut statement = statements.remove(0);
795-
796-
// Attempt to rewrite
797-
statement = rewrite(statement, &self.sql_rewrite_rules);
774+
let mut statements = self
775+
.sql_parser
776+
.parse(sql)
777+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
778+
let statement = statements.remove(0);
798779

799780
let query = statement.to_string();
800781

datafusion-postgres/src/pg_catalog.rs

Lines changed: 61 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ use crate::pg_catalog::catalog_info::CatalogInfo;
2323

2424
pub mod catalog_info;
2525
pub mod empty_table;
26+
pub mod format_type;
2627
pub mod has_privilege_udf;
2728
pub mod pg_attribute;
2829
pub mod pg_class;
@@ -685,7 +686,7 @@ impl PgCatalogStaticTables {
685686
}
686687
}
687688

688-
pub fn create_current_schemas_udf(name: &str) -> ScalarUDF {
689+
pub fn create_current_schemas_udf() -> ScalarUDF {
689690
// Define the function implementation
690691
let func = move |args: &[ColumnarValue]| {
691692
let args = ColumnarValue::values_to_arrays(args)?;
@@ -708,15 +709,15 @@ pub fn create_current_schemas_udf(name: &str) -> ScalarUDF {
708709

709710
// Wrap the implementation in a scalar function
710711
create_udf(
711-
name,
712+
"current_schemas",
712713
vec![DataType::Boolean],
713714
DataType::List(Arc::new(Field::new("schema", DataType::Utf8, false))),
714715
Volatility::Immutable,
715716
Arc::new(func),
716717
)
717718
}
718719

719-
pub fn create_current_schema_udf(name: &str) -> ScalarUDF {
720+
pub fn create_current_schema_udf() -> ScalarUDF {
720721
// Define the function implementation
721722
let func = move |_args: &[ColumnarValue]| {
722723
// Create a UTF8 array with a single value
@@ -729,15 +730,15 @@ pub fn create_current_schema_udf(name: &str) -> ScalarUDF {
729730

730731
// Wrap the implementation in a scalar function
731732
create_udf(
732-
name,
733+
"current_schema",
733734
vec![],
734735
DataType::Utf8,
735736
Volatility::Immutable,
736737
Arc::new(func),
737738
)
738739
}
739740

740-
pub fn create_current_database_udf(name: &str) -> ScalarUDF {
741+
pub fn create_current_database_udf() -> ScalarUDF {
741742
// Define the function implementation
742743
let func = move |_args: &[ColumnarValue]| {
743744
// Create a UTF8 array with a single value
@@ -750,30 +751,7 @@ pub fn create_current_database_udf(name: &str) -> ScalarUDF {
750751

751752
// Wrap the implementation in a scalar function
752753
create_udf(
753-
name,
754-
vec![],
755-
DataType::Utf8,
756-
Volatility::Immutable,
757-
Arc::new(func),
758-
)
759-
}
760-
761-
pub fn create_version_udf() -> ScalarUDF {
762-
// Define the function implementation
763-
let func = move |_args: &[ColumnarValue]| {
764-
// Create a UTF8 array with version information
765-
let mut builder = StringBuilder::new();
766-
// TODO: improve version string generation
767-
builder
768-
.append_value("DataFusion PostgreSQL 48.0.0 on x86_64-pc-linux-gnu, compiled by Rust");
769-
let array: ArrayRef = Arc::new(builder.finish());
770-
771-
Ok(ColumnarValue::Array(array))
772-
};
773-
774-
// Wrap the implementation in a scalar function
775-
create_udf(
776-
"version",
754+
"current_database",
777755
vec![],
778756
DataType::Utf8,
779757
Volatility::Immutable,
@@ -800,15 +778,15 @@ pub fn create_pg_get_userbyid_udf() -> ScalarUDF {
800778

801779
// Wrap the implementation in a scalar function
802780
create_udf(
803-
"pg_catalog.pg_get_userbyid",
781+
"pg_get_userbyid",
804782
vec![DataType::Int32],
805783
DataType::Utf8,
806784
Volatility::Stable,
807785
Arc::new(func),
808786
)
809787
}
810788

811-
pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {
789+
pub fn create_pg_table_is_visible() -> ScalarUDF {
812790
// Define the function implementation
813791
let func = move |args: &[ColumnarValue]| {
814792
let args = ColumnarValue::values_to_arrays(args)?;
@@ -827,24 +805,42 @@ pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {
827805

828806
// Wrap the implementation in a scalar function
829807
create_udf(
830-
name,
808+
"pg_table_is_visible",
831809
vec![DataType::Int32],
832810
DataType::Boolean,
833811
Volatility::Stable,
834812
Arc::new(func),
835813
)
836814
}
837815

838-
pub fn create_format_type_udf() -> ScalarUDF {
816+
pub fn create_session_user_udf() -> ScalarUDF {
817+
let func = move |_args: &[ColumnarValue]| {
818+
let mut builder = StringBuilder::new();
819+
// TODO: return real user
820+
builder.append_value("postgres");
821+
822+
let array: ArrayRef = Arc::new(builder.finish());
823+
824+
Ok(ColumnarValue::Array(array))
825+
};
826+
827+
create_udf(
828+
"session_user",
829+
vec![],
830+
DataType::Utf8,
831+
Volatility::Stable,
832+
Arc::new(func),
833+
)
834+
}
835+
836+
pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
839837
let func = move |args: &[ColumnarValue]| {
840838
let args = ColumnarValue::values_to_arrays(args)?;
841-
let type_oids = &args[0]; // Table (can be name or OID)
842-
let _type_mods = &args[1]; // Privilege type (SELECT, INSERT, etc.)
839+
let oid = &args[0];
843840

844-
// For now, always return true (full access for current user)
845841
let mut builder = StringBuilder::new();
846-
for _ in 0..type_oids.len() {
847-
builder.append_value("???");
842+
for _ in 0..oid.len() {
843+
builder.append_value("");
848844
}
849845

850846
let array: ArrayRef = Arc::new(builder.finish());
@@ -853,42 +849,46 @@ pub fn create_format_type_udf() -> ScalarUDF {
853849
};
854850

855851
create_udf(
856-
"format_type",
857-
vec![DataType::Int64, DataType::Int32],
852+
"pg_get_partkeydef",
853+
vec![DataType::Utf8],
858854
DataType::Utf8,
859855
Volatility::Stable,
860856
Arc::new(func),
861857
)
862858
}
863859

864-
pub fn create_session_user_udf() -> ScalarUDF {
865-
let func = move |_args: &[ColumnarValue]| {
866-
let mut builder = StringBuilder::new();
867-
// TODO: return real user
868-
builder.append_value("postgres");
860+
pub fn create_pg_relation_is_publishable_udf() -> ScalarUDF {
861+
let func = move |args: &[ColumnarValue]| {
862+
let args = ColumnarValue::values_to_arrays(args)?;
863+
let oid = &args[0];
864+
865+
let mut builder = BooleanBuilder::new();
866+
for _ in 0..oid.len() {
867+
builder.append_value(true);
868+
}
869869

870870
let array: ArrayRef = Arc::new(builder.finish());
871871

872872
Ok(ColumnarValue::Array(array))
873873
};
874874

875875
create_udf(
876-
"session_user",
877-
vec![],
878-
DataType::Utf8,
876+
"pg_relation_is_publishable",
877+
vec![DataType::Int32],
878+
DataType::Boolean,
879879
Volatility::Stable,
880880
Arc::new(func),
881881
)
882882
}
883883

884-
pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
884+
pub fn create_pg_get_statisticsobjdef_columns_udf() -> ScalarUDF {
885885
let func = move |args: &[ColumnarValue]| {
886886
let args = ColumnarValue::values_to_arrays(args)?;
887887
let oid = &args[0];
888888

889-
let mut builder = StringBuilder::new();
889+
let mut builder = BooleanBuilder::new();
890890
for _ in 0..oid.len() {
891-
builder.append_value("");
891+
builder.append_null();
892892
}
893893

894894
let array: ArrayRef = Arc::new(builder.finish());
@@ -897,8 +897,8 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
897897
};
898898

899899
create_udf(
900-
"pg_catalog.pg_get_partkeydef",
901-
vec![DataType::Utf8],
900+
"pg_get_statisticsobjdef_columns",
901+
vec![DataType::UInt32],
902902
DataType::Utf8,
903903
Volatility::Stable,
904904
Arc::new(func),
@@ -924,38 +924,28 @@ pub fn setup_pg_catalog(
924924
})?
925925
.register_schema("pg_catalog", Arc::new(pg_catalog))?;
926926

927-
session_context.register_udf(create_current_database_udf("current_database"));
928-
session_context.register_udf(create_current_schema_udf("current_schema"));
929-
session_context.register_udf(create_current_schema_udf("pg_catalog.current_schema"));
930-
session_context.register_udf(create_current_schemas_udf("current_schemas"));
931-
session_context.register_udf(create_current_schemas_udf("pg_catalog.current_schemas"));
932-
session_context.register_udf(create_version_udf());
927+
session_context.register_udf(create_current_database_udf());
928+
session_context.register_udf(create_current_schema_udf());
929+
session_context.register_udf(create_current_schemas_udf());
930+
// session_context.register_udf(create_version_udf());
933931
session_context.register_udf(create_pg_get_userbyid_udf());
934932
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
935933
"has_table_privilege",
936934
));
937-
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
938-
"pg_catalog.has_table_privilege",
939-
));
940935
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
941936
"has_schema_privilege",
942937
));
943-
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
944-
"pg_catalog.has_schema_privilege",
945-
));
946938
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
947939
"has_any_column_privilege",
948940
));
949-
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
950-
"pg_catalog.has_any_column_privilege",
951-
));
952-
session_context.register_udf(create_pg_table_is_visible("pg_table_is_visible"));
953-
session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible"));
954-
session_context.register_udf(create_format_type_udf());
941+
session_context.register_udf(create_pg_table_is_visible());
942+
session_context.register_udf(format_type::create_format_type_udf());
955943
session_context.register_udf(create_session_user_udf());
956944
session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
957945
session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf());
958946
session_context.register_udf(create_pg_get_partkeydef_udf());
947+
session_context.register_udf(create_pg_relation_is_publishable_udf());
948+
session_context.register_udf(create_pg_get_statisticsobjdef_columns_udf());
959949

960950
Ok(())
961951
}

0 commit comments

Comments
 (0)