Skip to content

Commit 7f91a89

Browse files
authored
test: include queries for pgcli startup (#149)
* test: include queries for pgcli startup * test: add compatibility tests for pgcli startup queries * fix: fallback to stringarray or string if no type information provided Signed-off-by: Ning Sun <[email protected]> * chore: add pgcli on readme * chore: udpate comments Signed-off-by: Ning Sun <[email protected]> * chore: remove dbg --------- Signed-off-by: Ning Sun <[email protected]>
1 parent cb841ea commit 7f91a89

File tree

8 files changed

+424
-69
lines changed

8 files changed

+424
-69
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ project.
2323
- Built-in `pg_catalog` tables
2424
- Built-in postgres functions for common meta queries
2525
- [x] DBeaver compatibility
26+
- [x] pgcli compatibility
2627
- `datafusion-postgres-cli`: A cli tool starts a postgres compatible server for
2728
datafusion supported file formats, just like python's `SimpleHTTPServer`.
2829
- `arrow-pg`: A data type mapping, encoding/decoding library for arrow and

arrow-pg/src/datatypes/df.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ where
6666
} else if let Some(infer_type) = inferenced_type {
6767
into_pg_type(infer_type)
6868
} else {
69-
// Default to TEXT for untyped parameters in extended queries
70-
// This allows arithmetic operations to work with implicit casting
71-
Ok(Type::TEXT)
69+
Ok(Type::UNKNOWN)
7270
}
7371
}
7472

@@ -262,10 +260,28 @@ where
262260
}
263261
// TODO: add more advanced types (composite types, ranges, etc.)
264262
_ => {
265-
// Default to string/text for unsupported parameter types
266-
// This allows graceful degradation instead of fatal errors
263+
// the client didn't provide type information and we are also
264+
// unable to inference the type, or it's a type that we haven't
265+
// supported:
266+
//
267+
// In this case we retry to resolve it as String or StringArray
267268
let value = portal.parameter::<String>(i, &pg_type)?;
268-
deserialized_params.push(ScalarValue::Utf8(value));
269+
if let Some(value) = value {
270+
if value.starts_with('{') && value.ends_with('}') {
271+
// Looks like an array
272+
let items = value.trim_matches(|c| c == '{' || c == '}' || c == ' ');
273+
let items = items.split(',').map(|s| s.trim());
274+
let scalar_values: Vec<ScalarValue> = items
275+
.map(|s| ScalarValue::Utf8(Some(s.to_string())))
276+
.collect();
277+
278+
deserialized_params.push(ScalarValue::List(
279+
ScalarValue::new_list_nullable(&scalar_values, &DataType::Utf8),
280+
));
281+
} else {
282+
deserialized_params.push(ScalarValue::Utf8(Some(value)));
283+
}
284+
}
269285
}
270286
}
271287
}

datafusion-postgres/src/handlers.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
55
use crate::sql::{
6-
parse, rewrite, AliasDuplicatedProjectionRewrite, FixArrayLiteral, PrependUnqualifiedTableName,
7-
RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
8-
RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
6+
parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter, FixArrayLiteral,
7+
PrependUnqualifiedPgTableName, RemoveTableFunctionQualifier, RemoveUnsupportedTypes,
8+
ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
99
};
1010
use async_trait::async_trait;
1111
use datafusion::arrow::datatypes::DataType;
1212
use datafusion::logical_expr::LogicalPlan;
1313
use datafusion::prelude::*;
1414
use datafusion::sql::parser::Statement;
15-
use log::warn;
15+
use log::{info, warn};
1616
use pgwire::api::auth::noop::NoopStartupHandler;
1717
use pgwire::api::auth::StartupHandler;
1818
use pgwire::api::portal::{Format, Portal};
@@ -23,7 +23,7 @@ use pgwire::api::results::{
2323
};
2424
use pgwire::api::stmt::QueryParser;
2525
use pgwire::api::stmt::StoredStatement;
26-
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
26+
use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
2727
use pgwire::error::{PgWireError, PgWireResult};
2828
use pgwire::messages::response::TransactionStatus;
2929
use tokio::sync::Mutex;
@@ -65,6 +65,21 @@ impl PgWireServerHandlers for HandlerFactory {
6565
fn startup_handler(&self) -> Arc<impl StartupHandler> {
6666
Arc::new(SimpleStartupHandler)
6767
}
68+
69+
fn error_handler(&self) -> Arc<impl ErrorHandler> {
70+
Arc::new(LoggingErrorHandler)
71+
}
72+
}
73+
74+
struct LoggingErrorHandler;
75+
76+
impl ErrorHandler for LoggingErrorHandler {
77+
fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
78+
where
79+
C: ClientInfo,
80+
{
81+
info!("Sending error: {error}")
82+
}
6883
}
6984

7085
/// The pgwire handler backed by a datafusion `SessionContext`
@@ -82,11 +97,14 @@ impl DfSessionService {
8297
auth_manager: Arc<AuthManager>,
8398
) -> DfSessionService {
8499
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
100+
// make sure blacklist based rewriter it on the top to prevent sql
101+
// being rewritten from other rewriters
102+
Arc::new(BlacklistSqlRewriter::new()),
85103
Arc::new(AliasDuplicatedProjectionRewrite),
86104
Arc::new(ResolveUnqualifiedIdentifer),
87105
Arc::new(RemoveUnsupportedTypes::new()),
88106
Arc::new(RewriteArrayAnyAllOperation),
89-
Arc::new(PrependUnqualifiedTableName::new()),
107+
Arc::new(PrependUnqualifiedPgTableName),
90108
Arc::new(FixArrayLiteral),
91109
Arc::new(RemoveTableFunctionQualifier),
92110
];
@@ -649,7 +667,9 @@ impl ExtendedQueryHandler for DfSessionService {
649667
let param_types = plan
650668
.get_parameter_types()
651669
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
670+
652671
let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
672+
653673
let plan = plan
654674
.clone()
655675
.replace_params_with_values(&param_values)
@@ -678,12 +698,10 @@ impl ExtendedQueryHandler for DfSessionService {
678698
})?
679699
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
680700
} else {
681-
match self.session_context.execute_logical_plan(optimised).await {
682-
Ok(df) => df,
683-
Err(e) => {
684-
return Err(PgWireError::ApiError(Box::new(e)));
685-
}
686-
}
701+
self.session_context
702+
.execute_logical_plan(optimised)
703+
.await
704+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
687705
}
688706
};
689707
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;

datafusion-postgres/src/pg_catalog.rs

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use tokio::sync::RwLock;
2323
mod pg_attribute;
2424
mod pg_class;
2525
mod pg_database;
26+
mod pg_get_expr_udf;
2627
mod pg_namespace;
2728
mod pg_settings;
2829

@@ -917,32 +918,6 @@ pub fn create_session_user_udf() -> ScalarUDF {
917918
)
918919
}
919920

920-
pub fn create_pg_get_expr_udf() -> ScalarUDF {
921-
let func = move |args: &[ColumnarValue]| {
922-
let args = ColumnarValue::values_to_arrays(args)?;
923-
let expr = &args[0];
924-
let _oid = &args[1];
925-
926-
// For now, always return true (full access for current user)
927-
let mut builder = StringBuilder::new();
928-
for _ in 0..expr.len() {
929-
builder.append_value("");
930-
}
931-
932-
let array: ArrayRef = Arc::new(builder.finish());
933-
934-
Ok(ColumnarValue::Array(array))
935-
};
936-
937-
create_udf(
938-
"pg_catalog.pg_get_expr",
939-
vec![DataType::Utf8, DataType::Int32],
940-
DataType::Utf8,
941-
Volatility::Stable,
942-
Arc::new(func),
943-
)
944-
}
945-
946921
pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
947922
let func = move |args: &[ColumnarValue]| {
948923
let args = ColumnarValue::values_to_arrays(args)?;
@@ -996,7 +971,7 @@ pub fn setup_pg_catalog(
996971
session_context.register_udf(create_format_type_udf());
997972
session_context.register_udf(create_session_user_udf());
998973
session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
999-
session_context.register_udf(create_pg_get_expr_udf());
974+
session_context.register_udf(pg_get_expr_udf::PgGetExprUDF::new().into_scalar_udf());
1000975
session_context.register_udf(create_pg_get_partkeydef_udf());
1001976

1002977
Ok(())
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::arrow::array::{ArrayRef, StringBuilder};
4+
use datafusion::error::Result;
5+
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF};
6+
use datafusion::{
7+
arrow::datatypes::DataType,
8+
logical_expr::{ScalarUDFImpl, Signature, TypeSignature, Volatility},
9+
};
10+
11+
#[derive(Debug)]
12+
pub struct PgGetExprUDF {
13+
signature: Signature,
14+
name: &'static str,
15+
}
16+
17+
impl PgGetExprUDF {
18+
pub(crate) fn new() -> PgGetExprUDF {
19+
Self {
20+
signature: Signature::one_of(
21+
vec![
22+
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int32]),
23+
TypeSignature::Exact(vec![DataType::Utf8, DataType::UInt32]),
24+
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
25+
TypeSignature::Exact(vec![DataType::Utf8, DataType::UInt64]),
26+
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int32, DataType::Boolean]),
27+
TypeSignature::Exact(vec![DataType::Utf8, DataType::UInt32, DataType::Boolean]),
28+
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Boolean]),
29+
TypeSignature::Exact(vec![DataType::Utf8, DataType::UInt64, DataType::Boolean]),
30+
],
31+
Volatility::Stable,
32+
),
33+
name: "pg_catalog.pg_get_expr",
34+
}
35+
}
36+
37+
pub fn into_scalar_udf(self) -> ScalarUDF {
38+
ScalarUDF::new_from_impl(self).with_aliases(vec!["pg_get_expr"])
39+
}
40+
}
41+
42+
impl ScalarUDFImpl for PgGetExprUDF {
43+
fn signature(&self) -> &Signature {
44+
&self.signature
45+
}
46+
47+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
48+
Ok(DataType::Utf8)
49+
}
50+
51+
fn name(&self) -> &str {
52+
self.name
53+
}
54+
55+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
56+
let args = ColumnarValue::values_to_arrays(&args.args)?;
57+
let expr = &args[0];
58+
let _oid = &args[1];
59+
60+
// For now, always return true (full access for current user)
61+
let mut builder = StringBuilder::new();
62+
for _ in 0..expr.len() {
63+
builder.append_value("");
64+
}
65+
66+
let array: ArrayRef = Arc::new(builder.finish());
67+
68+
Ok(ColumnarValue::Array(array))
69+
}
70+
71+
fn as_any(&self) -> &dyn std::any::Any {
72+
self
73+
}
74+
}

0 commit comments

Comments
 (0)