Skip to content

Commit 726e7ed

Browse files
committed
fix: refactor has_privilege udfs
1 parent 3956bfb commit 726e7ed

File tree

3 files changed

+89
-76
lines changed

3 files changed

+89
-76
lines changed

datafusion-postgres/src/pg_catalog.rs

Lines changed: 14 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ use std::sync::Arc;
44

55
use async_trait::async_trait;
66
use datafusion::arrow::array::{
7-
as_boolean_array, ArrayRef, BooleanArray, BooleanBuilder, RecordBatch, StringArray,
8-
StringBuilder,
7+
as_boolean_array, ArrayRef, BooleanBuilder, RecordBatch, StringArray, StringBuilder,
98
};
109
use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
1110
use datafusion::arrow::ipc::reader::FileReader;
@@ -21,6 +20,7 @@ use postgres_types::Oid;
2120
use tokio::sync::RwLock;
2221

2322
mod empty_table;
23+
mod has_privilege_udf;
2424
mod pg_attribute;
2525
mod pg_class;
2626
mod pg_database;
@@ -863,62 +863,6 @@ pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {
863863
)
864864
}
865865

866-
pub fn create_has_privilege_3param_udf(name: &str) -> ScalarUDF {
867-
// Define the function implementation for 3-parameter version
868-
let func = move |args: &[ColumnarValue]| {
869-
let args = ColumnarValue::values_to_arrays(args)?;
870-
let user = &args[0]; // User (can be name or OID)
871-
let _obj = &args[1]; // Table (can be name or OID)
872-
let _privilege = &args[2]; // Privilege type (SELECT, INSERT, etc.)
873-
874-
// For now, always return true (full access)
875-
let mut builder = BooleanArray::builder(user.len());
876-
for _ in 0..user.len() {
877-
builder.append_value(true);
878-
}
879-
880-
let array: ArrayRef = Arc::new(builder.finish());
881-
882-
Ok(ColumnarValue::Array(array))
883-
};
884-
885-
// Wrap the implementation in a scalar function
886-
create_udf(
887-
name,
888-
vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
889-
DataType::Boolean,
890-
Volatility::Stable,
891-
Arc::new(func),
892-
)
893-
}
894-
895-
pub fn create_has_privilege_2param_udf(name: &str) -> ScalarUDF {
896-
// Define the function implementation for 2-parameter version (current user, table, privilege)
897-
let func = move |args: &[ColumnarValue]| {
898-
let args = ColumnarValue::values_to_arrays(args)?;
899-
let obj = &args[0]; // Table (can be name or OID)
900-
let _privilege = &args[1]; // Privilege type (SELECT, INSERT, etc.)
901-
902-
// For now, always return true (full access for current user)
903-
let mut builder = BooleanArray::builder(obj.len());
904-
for _ in 0..obj.len() {
905-
builder.append_value(true);
906-
}
907-
let array: ArrayRef = Arc::new(builder.finish());
908-
909-
Ok(ColumnarValue::Array(array))
910-
};
911-
912-
// Wrap the implementation in a scalar function
913-
create_udf(
914-
name,
915-
vec![DataType::Utf8, DataType::Utf8],
916-
DataType::Boolean,
917-
Volatility::Stable,
918-
Arc::new(func),
919-
)
920-
}
921-
922866
pub fn create_format_type_udf() -> ScalarUDF {
923867
let func = move |args: &[ColumnarValue]| {
924868
let args = ColumnarValue::values_to_arrays(args)?;
@@ -1015,36 +959,30 @@ pub fn setup_pg_catalog(
1015959
session_context.register_udf(create_current_schemas_udf("pg_catalog.current_schemas"));
1016960
session_context.register_udf(create_version_udf());
1017961
session_context.register_udf(create_pg_get_userbyid_udf());
1018-
session_context.register_udf(create_has_privilege_2param_udf("has_table_privilege"));
1019-
session_context.register_udf(create_has_privilege_2param_udf(
1020-
"pg_catalog.has_table_privilege",
962+
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
963+
"has_table_privilege",
1021964
));
1022-
session_context.register_udf(create_has_privilege_2param_udf("has_schema_privilege"));
1023-
session_context.register_udf(create_has_privilege_2param_udf(
1024-
"pg_catalog.has_schema_privilege",
1025-
));
1026-
session_context.register_udf(create_has_privilege_2param_udf("has_any_column_privilege"));
1027-
session_context.register_udf(create_has_privilege_2param_udf(
1028-
"pg_catalog.has_any_column_privilege",
1029-
));
1030-
session_context.register_udf(create_has_privilege_3param_udf("has_table_privilege"));
1031-
session_context.register_udf(create_has_privilege_3param_udf(
965+
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
1032966
"pg_catalog.has_table_privilege",
1033967
));
1034-
session_context.register_udf(create_has_privilege_3param_udf("has_schema_privilege"));
1035-
session_context.register_udf(create_has_privilege_3param_udf(
968+
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
969+
"has_schema_privilege",
970+
));
971+
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
1036972
"pg_catalog.has_schema_privilege",
1037973
));
1038-
session_context.register_udf(create_has_privilege_3param_udf("has_any_column_privilege"));
1039-
session_context.register_udf(create_has_privilege_3param_udf(
974+
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
975+
"has_any_column_privilege",
976+
));
977+
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
1040978
"pg_catalog.has_any_column_privilege",
1041979
));
1042980
session_context.register_udf(create_pg_table_is_visible("pg_catalog"));
1043981
session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible"));
1044982
session_context.register_udf(create_format_type_udf());
1045983
session_context.register_udf(create_session_user_udf());
1046984
session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
1047-
session_context.register_udf(pg_get_expr_udf::PgGetExprUDF::new().into_scalar_udf());
985+
session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf());
1048986
session_context.register_udf(create_pg_get_partkeydef_udf());
1049987

1050988
Ok(())
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::arrow::array::{ArrayRef, BooleanArray};
4+
use datafusion::error::Result;
5+
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF};
6+
use datafusion::{
7+
arrow::datatypes::DataType,
8+
logical_expr::{ScalarUDFImpl, Signature, TypeSignature, Volatility},
9+
};
10+
11+
#[derive(Debug)]
12+
pub struct PgHasPrivilegeUDF {
13+
signature: Signature,
14+
name: String,
15+
}
16+
17+
impl PgHasPrivilegeUDF {
18+
pub(crate) fn new(name: &str) -> PgHasPrivilegeUDF {
19+
Self {
20+
signature: Signature::one_of(
21+
vec![
22+
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
23+
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
24+
],
25+
Volatility::Stable,
26+
),
27+
name: name.to_owned(),
28+
}
29+
}
30+
31+
pub fn into_scalar_udf(self) -> ScalarUDF {
32+
ScalarUDF::new_from_impl(self)
33+
}
34+
}
35+
36+
impl ScalarUDFImpl for PgHasPrivilegeUDF {
37+
fn signature(&self) -> &Signature {
38+
&self.signature
39+
}
40+
41+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
42+
Ok(DataType::Boolean)
43+
}
44+
45+
fn name(&self) -> &str {
46+
self.name.as_ref()
47+
}
48+
49+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
50+
let args = ColumnarValue::values_to_arrays(&args.args)?;
51+
52+
let len = args[0].len();
53+
54+
// For now, always return true (full access for current user)
55+
let mut builder = BooleanArray::builder(len);
56+
for _ in 0..len {
57+
builder.append_value(true);
58+
}
59+
let array: ArrayRef = Arc::new(builder.finish());
60+
61+
Ok(ColumnarValue::Array(array))
62+
}
63+
64+
fn as_any(&self) -> &dyn std::any::Any {
65+
self
66+
}
67+
}
68+
69+
pub fn create_has_privilege_udf(name: &str) -> ScalarUDF {
70+
PgHasPrivilegeUDF::new(name).into_scalar_udf()
71+
}

datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,7 @@ impl ScalarUDFImpl for PgGetExprUDF {
7272
self
7373
}
7474
}
75+
76+
pub fn create_pg_get_expr_udf() -> ScalarUDF {
77+
PgGetExprUDF::new().into_scalar_udf()
78+
}

0 commit comments

Comments
 (0)