Skip to content

Commit 0503bc9

Browse files
committed
feat: improve support for current_user and privilege functions
Signed-off-by: Ning Sun <[email protected]>
1 parent 8350ac4 commit 0503bc9

File tree

3 files changed

+108
-17
lines changed

3 files changed

+108
-17
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
55
use crate::sql::{
6-
parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter, FixArrayLiteral,
7-
PrependUnqualifiedPgTableName, RemoveTableFunctionQualifier, RemoveUnsupportedTypes,
8-
ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
6+
parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter,
7+
CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, PrependUnqualifiedPgTableName,
8+
RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
9+
RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
910
};
1011
use async_trait::async_trait;
1112
use datafusion::arrow::datatypes::DataType;
@@ -107,6 +108,7 @@ impl DfSessionService {
107108
Arc::new(PrependUnqualifiedPgTableName),
108109
Arc::new(FixArrayLiteral),
109110
Arc::new(RemoveTableFunctionQualifier),
111+
Arc::new(CurrentUserVariableToSessionUserFunctionCall),
110112
];
111113
let parser = Arc::new(Parser {
112114
session_context: session_context.clone(),

datafusion-postgres/src/pg_catalog.rs

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ pub fn create_pg_get_userbyid_udf() -> ScalarUDF {
836836
)
837837
}
838838

839-
pub fn create_pg_table_is_visible() -> ScalarUDF {
839+
pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {
840840
// Define the function implementation
841841
let func = move |args: &[ColumnarValue]| {
842842
let args = ColumnarValue::values_to_arrays(args)?;
@@ -855,20 +855,20 @@ pub fn create_pg_table_is_visible() -> ScalarUDF {
855855

856856
// Wrap the implementation in a scalar function
857857
create_udf(
858-
"pg_catalog.pg_table_is_visible",
858+
name,
859859
vec![DataType::Int32],
860860
DataType::Boolean,
861861
Volatility::Stable,
862862
Arc::new(func),
863863
)
864864
}
865865

866-
pub fn create_has_table_privilege_3param_udf() -> ScalarUDF {
866+
pub fn create_has_privilege_3param_udf(name: &str) -> ScalarUDF {
867867
// Define the function implementation for 3-parameter version
868868
let func = move |args: &[ColumnarValue]| {
869869
let args = ColumnarValue::values_to_arrays(args)?;
870870
let user = &args[0]; // User (can be name or OID)
871-
let _table = &args[1]; // Table (can be name or OID)
871+
let _obj = &args[1]; // Table (can be name or OID)
872872
let _privilege = &args[2]; // Privilege type (SELECT, INSERT, etc.)
873873

874874
// For now, always return true (full access)
@@ -884,24 +884,24 @@ pub fn create_has_table_privilege_3param_udf() -> ScalarUDF {
884884

885885
// Wrap the implementation in a scalar function
886886
create_udf(
887-
"has_table_privilege",
887+
name,
888888
vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
889889
DataType::Boolean,
890890
Volatility::Stable,
891891
Arc::new(func),
892892
)
893893
}
894894

895-
pub fn create_has_table_privilege_2param_udf() -> ScalarUDF {
895+
pub fn create_has_privilege_2param_udf(name: &str) -> ScalarUDF {
896896
// Define the function implementation for 2-parameter version (current user, table, privilege)
897897
let func = move |args: &[ColumnarValue]| {
898898
let args = ColumnarValue::values_to_arrays(args)?;
899-
let table = &args[0]; // Table (can be name or OID)
899+
let obj = &args[0]; // Table (can be name or OID)
900900
let _privilege = &args[1]; // Privilege type (SELECT, INSERT, etc.)
901901

902902
// For now, always return true (full access for current user)
903-
let mut builder = BooleanArray::builder(table.len());
904-
for _ in 0..table.len() {
903+
let mut builder = BooleanArray::builder(obj.len());
904+
for _ in 0..obj.len() {
905905
builder.append_value(true);
906906
}
907907
let array: ArrayRef = Arc::new(builder.finish());
@@ -911,7 +911,7 @@ pub fn create_has_table_privilege_2param_udf() -> ScalarUDF {
911911

912912
// Wrap the implementation in a scalar function
913913
create_udf(
914-
"has_table_privilege",
914+
name,
915915
vec![DataType::Utf8, DataType::Utf8],
916916
DataType::Boolean,
917917
Volatility::Stable,
@@ -970,7 +970,6 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
970970
let args = ColumnarValue::values_to_arrays(args)?;
971971
let oid = &args[0];
972972

973-
// For now, always return true (full access for current user)
974973
let mut builder = StringBuilder::new();
975974
for _ in 0..oid.len() {
976975
builder.append_value("");
@@ -1016,9 +1015,32 @@ pub fn setup_pg_catalog(
10161015
session_context.register_udf(create_current_schemas_udf("pg_catalog.current_schemas"));
10171016
session_context.register_udf(create_version_udf());
10181017
session_context.register_udf(create_pg_get_userbyid_udf());
1019-
session_context.register_udf(create_has_table_privilege_2param_udf());
1020-
session_context.register_udf(create_has_table_privilege_3param_udf());
1021-
session_context.register_udf(create_pg_table_is_visible());
1018+
session_context.register_udf(create_has_privilege_2param_udf("has_table_privilege"));
1019+
session_context.register_udf(create_has_privilege_2param_udf(
1020+
"pg_catalog.has_table_privilege",
1021+
));
1022+
session_context.register_udf(create_has_privilege_2param_udf("has_schema_privilege"));
1023+
session_context.register_udf(create_has_privilege_2param_udf(
1024+
"pg_catalog.has_schema_privilege",
1025+
));
1026+
session_context.register_udf(create_has_privilege_2param_udf("has_any_column_privilege"));
1027+
session_context.register_udf(create_has_privilege_2param_udf(
1028+
"pg_catalog.has_any_column_privilege",
1029+
));
1030+
session_context.register_udf(create_has_privilege_3param_udf("has_table_privilege"));
1031+
session_context.register_udf(create_has_privilege_3param_udf(
1032+
"pg_catalog.has_table_privilege",
1033+
));
1034+
session_context.register_udf(create_has_privilege_3param_udf("has_schema_privilege"));
1035+
session_context.register_udf(create_has_privilege_3param_udf(
1036+
"pg_catalog.has_schema_privilege",
1037+
));
1038+
session_context.register_udf(create_has_privilege_3param_udf("has_any_column_privilege"));
1039+
session_context.register_udf(create_has_privilege_3param_udf(
1040+
"pg_catalog.has_any_column_privilege",
1041+
));
1042+
session_context.register_udf(create_pg_table_is_visible("pg_catalog"));
1043+
session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible"));
10221044
session_context.register_udf(create_format_type_udf());
10231045
session_context.register_udf(create_session_user_udf());
10241046
session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());

datafusion-postgres/src/sql.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,57 @@ impl SqlStatementRewriteRule for RemoveTableFunctionQualifier {
613613
}
614614
}
615615

616+
/// Replace `current_user` with `session_user()`
617+
#[derive(Debug)]
618+
pub struct CurrentUserVariableToSessionUserFunctionCall;
619+
620+
struct CurrentUserVariableToSessionUserFunctionCallVisitor;
621+
622+
impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
623+
type Break = ();
624+
625+
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
626+
if let Expr::Identifier(ident) = expr {
627+
if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
628+
*expr = Expr::Function(Function {
629+
name: ObjectName::from(vec![Ident::new("session_user")]),
630+
args: FunctionArguments::None,
631+
uses_odbc_syntax: false,
632+
parameters: FunctionArguments::None,
633+
filter: None,
634+
null_treatment: None,
635+
over: None,
636+
within_group: vec![],
637+
});
638+
}
639+
}
640+
641+
if let Expr::Function(func) = expr {
642+
let fname = func
643+
.name
644+
.0
645+
.iter()
646+
.map(|ident| ident.to_string())
647+
.collect::<Vec<String>>()
648+
.join(".");
649+
if fname.to_lowercase() == "current_user" {
650+
func.name = ObjectName::from(vec![Ident::new("session_user")])
651+
}
652+
}
653+
654+
ControlFlow::Continue(())
655+
}
656+
}
657+
658+
impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
659+
fn rewrite(&self, mut s: Statement) -> Statement {
660+
let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
661+
662+
let _ = s.visit(&mut visitor);
663+
s
664+
}
665+
}
666+
616667
#[cfg(test)]
617668
mod tests {
618669
use super::*;
@@ -802,4 +853,20 @@ mod tests {
802853
"SELECT * FROM pg_get_keywords()"
803854
);
804855
}
856+
857+
#[test]
858+
fn test_current_user() {
859+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
860+
vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];
861+
862+
assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");
863+
864+
assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");
865+
866+
assert_rewrite!(
867+
&rules,
868+
"SELECT is_null(current_user)",
869+
"SELECT is_null(session_user)"
870+
);
871+
}
805872
}

0 commit comments

Comments
 (0)