Skip to content

Commit 7c1aa14

Browse files
committed
Merge branch 'master' into feat/pg-catalog-sql-dbeaver
2 parents 975fef0 + e5644ec commit 7c1aa14

File tree

3 files changed

+311
-22
lines changed

3 files changed

+311
-22
lines changed

arrow-pg/src/encoder.rs

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -513,38 +513,31 @@ pub fn encode_value<T: Encoder>(
513513
if arr.is_null(idx) {
514514
return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
515515
}
516-
// Get the dictionary values, ignoring keys
517-
// We'll use Int32Type as a common key type, but we're only interested in values
518-
macro_rules! get_dict_values {
516+
// Get the dictionary values and the mapped row index
517+
macro_rules! get_dict_values_and_index {
519518
($key_type:ty) => {
520519
arr.as_any()
521520
.downcast_ref::<DictionaryArray<$key_type>>()
522-
.map(|dict| dict.values())
521+
.map(|dict| (dict.values(), dict.keys().value(idx) as usize))
523522
};
524523
}
525524

526525
// Try to extract values using different key types
527-
let values = get_dict_values!(Int8Type)
528-
.or_else(|| get_dict_values!(Int16Type))
529-
.or_else(|| get_dict_values!(Int32Type))
530-
.or_else(|| get_dict_values!(Int64Type))
531-
.or_else(|| get_dict_values!(UInt8Type))
532-
.or_else(|| get_dict_values!(UInt16Type))
533-
.or_else(|| get_dict_values!(UInt32Type))
534-
.or_else(|| get_dict_values!(UInt64Type))
526+
let (values, idx) = get_dict_values_and_index!(Int8Type)
527+
.or_else(|| get_dict_values_and_index!(Int16Type))
528+
.or_else(|| get_dict_values_and_index!(Int32Type))
529+
.or_else(|| get_dict_values_and_index!(Int64Type))
530+
.or_else(|| get_dict_values_and_index!(UInt8Type))
531+
.or_else(|| get_dict_values_and_index!(UInt16Type))
532+
.or_else(|| get_dict_values_and_index!(UInt32Type))
533+
.or_else(|| get_dict_values_and_index!(UInt64Type))
535534
.ok_or_else(|| {
536535
ToSqlError::from(format!(
537536
"Unsupported dictionary key type for value type {value_type}"
538537
))
539538
})?;
540539

541-
// If the dictionary has only one value, treat it as a primitive
542-
if values.len() == 1 {
543-
encode_value(encoder, values, 0, type_, format)?
544-
} else {
545-
// Otherwise, use value directly indexed by values array
546-
encode_value(encoder, values, idx, type_, format)?
547-
}
540+
encode_value(encoder, values, idx, type_, format)?
548541
}
549542
_ => {
550543
return Err(PgWireError::ApiError(ToSqlError::from(format!(
@@ -557,3 +550,48 @@ pub fn encode_value<T: Encoder>(
557550

558551
Ok(())
559552
}
553+
554+
#[cfg(test)]
555+
mod tests {
556+
use super::*;
557+
558+
#[test]
559+
fn encodes_dictionary_array() {
560+
#[derive(Default)]
561+
struct MockEncoder {
562+
encoded_value: String,
563+
}
564+
565+
impl Encoder for MockEncoder {
566+
fn encode_field_with_type_and_format<T>(
567+
&mut self,
568+
value: &T,
569+
data_type: &Type,
570+
_format: FieldFormat,
571+
) -> PgWireResult<()>
572+
where
573+
T: ToSql + ToSqlText + Sized,
574+
{
575+
let mut bytes = BytesMut::new();
576+
let _sql_text = value.to_sql_text(data_type, &mut bytes);
577+
let string = String::from_utf8((&bytes).to_vec());
578+
self.encoded_value = string.unwrap();
579+
Ok(())
580+
}
581+
}
582+
583+
let val = "~!@&$[]()@@!!";
584+
let value = StringArray::from_iter_values([val]);
585+
let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
586+
let dict_arr: Arc<dyn Array> =
587+
Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());
588+
589+
let mut encoder = MockEncoder::default();
590+
591+
let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text);
592+
593+
assert!(result.is_ok());
594+
595+
assert!(encoder.encoded_value == val);
596+
}
597+
}

datafusion-postgres/src/handlers.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
5-
use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule};
5+
use crate::sql::{
6+
parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
7+
ResolveUnqualifiedIdentifer, SqlStatementRewriteRule,
8+
};
69
use async_trait::async_trait;
710
use datafusion::arrow::datatypes::DataType;
811
use datafusion::logical_expr::LogicalPlan;
@@ -74,8 +77,11 @@ impl DfSessionService {
7477
session_context: Arc<SessionContext>,
7578
auth_manager: Arc<AuthManager>,
7679
) -> DfSessionService {
77-
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
78-
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
80+
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
81+
Arc::new(AliasDuplicatedProjectionRewrite),
82+
Arc::new(ResolveUnqualifiedIdentifer),
83+
Arc::new(RemoveUnsupportedTypes::new()),
84+
];
7985
let parser = Arc::new(Parser {
8086
session_context: session_context.clone(),
8187
sql_rewrite_rules: sql_rewrite_rules.clone(),

datafusion-postgres/src/sql.rs

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
use std::collections::HashSet;
12
use std::sync::Arc;
23

34
use datafusion::sql::sqlparser::ast::Expr;
45
use datafusion::sql::sqlparser::ast::Ident;
6+
use datafusion::sql::sqlparser::ast::OrderByKind;
7+
use datafusion::sql::sqlparser::ast::Query;
58
use datafusion::sql::sqlparser::ast::Select;
69
use datafusion::sql::sqlparser::ast::SelectItem;
710
use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
811
use datafusion::sql::sqlparser::ast::SetExpr;
912
use datafusion::sql::sqlparser::ast::Statement;
13+
use datafusion::sql::sqlparser::ast::TableFactor;
14+
use datafusion::sql::sqlparser::ast::TableWithJoins;
15+
use datafusion::sql::sqlparser::ast::Value;
1016
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1117
use datafusion::sql::sqlparser::parser::Parser;
1218
use datafusion::sql::sqlparser::parser::ParserError;
@@ -135,6 +141,188 @@ impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
135141
}
136142
}
137143

144+
/// Prepend qualifier for order by or filter when there is qualified wildcard
145+
///
146+
/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not
147+
/// accepted by datafusion.
148+
#[derive(Debug)]
149+
pub struct ResolveUnqualifiedIdentifer;
150+
151+
impl ResolveUnqualifiedIdentifer {
152+
fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
153+
if let SetExpr::Select(select) = query.body.as_mut() {
154+
// Step 1: Find all table aliases from FROM and JOIN clauses.
155+
let table_aliases = Self::get_table_aliases(&select.from);
156+
157+
// Step 2: Check for a single qualified wildcard in the projection.
158+
let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
159+
if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
160+
return; // Conditions not met.
161+
}
162+
163+
let wildcard_alias = qualified_wildcard_alias.unwrap();
164+
165+
// Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
166+
if let Some(selection) = &mut select.selection {
167+
Self::rewrite_expr(selection, &wildcard_alias, &table_aliases);
168+
}
169+
170+
if let Some(OrderByKind::Expressions(order_by_exprs)) =
171+
query.order_by.as_mut().map(|o| &mut o.kind)
172+
{
173+
for order_by_expr in order_by_exprs {
174+
Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases);
175+
}
176+
}
177+
}
178+
}
179+
180+
fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
181+
let mut aliases = HashSet::new();
182+
for table_with_joins in tables {
183+
if let TableFactor::Table {
184+
alias: Some(alias), ..
185+
} = &table_with_joins.relation
186+
{
187+
aliases.insert(alias.name.value.clone());
188+
}
189+
for join in &table_with_joins.joins {
190+
if let TableFactor::Table {
191+
alias: Some(alias), ..
192+
} = &join.relation
193+
{
194+
aliases.insert(alias.name.value.clone());
195+
}
196+
}
197+
}
198+
aliases
199+
}
200+
201+
fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
202+
let mut qualified_wildcards = projection
203+
.iter()
204+
.filter_map(|item| {
205+
if let SelectItem::QualifiedWildcard(
206+
SelectItemQualifiedWildcardKind::ObjectName(objname),
207+
_,
208+
) = item
209+
{
210+
Some(
211+
objname
212+
.0
213+
.iter()
214+
.map(|v| v.as_ident().unwrap().value.clone())
215+
.collect::<Vec<_>>()
216+
.join("."),
217+
)
218+
} else {
219+
None
220+
}
221+
})
222+
.collect::<Vec<_>>();
223+
224+
if qualified_wildcards.len() == 1 {
225+
Some(qualified_wildcards.remove(0))
226+
} else {
227+
None
228+
}
229+
}
230+
231+
fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) {
232+
match expr {
233+
Expr::Identifier(ident) => {
234+
// If the identifier is not a table alias itself, rewrite it.
235+
if !table_aliases.contains(&ident.value) {
236+
*expr = Expr::CompoundIdentifier(vec![
237+
Ident::new(wildcard_alias.to_string()),
238+
ident.clone(),
239+
]);
240+
}
241+
}
242+
Expr::BinaryOp { left, right, .. } => {
243+
Self::rewrite_expr(left, wildcard_alias, table_aliases);
244+
Self::rewrite_expr(right, wildcard_alias, table_aliases);
245+
}
246+
// Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
247+
_ => {}
248+
}
249+
}
250+
}
251+
252+
impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
253+
fn rewrite(&self, mut statement: Statement) -> Statement {
254+
if let Statement::Query(query) = &mut statement {
255+
Self::rewrite_unqualified_identifiers(query);
256+
}
257+
258+
statement
259+
}
260+
}
261+
262+
/// Remove datafusion unsupported type annotations
263+
#[derive(Debug)]
264+
pub struct RemoveUnsupportedTypes {
265+
unsupported_types: HashSet<String>,
266+
}
267+
268+
impl RemoveUnsupportedTypes {
269+
pub fn new() -> Self {
270+
let mut unsupported_types = HashSet::new();
271+
unsupported_types.insert("regclass".to_owned());
272+
273+
Self { unsupported_types }
274+
}
275+
276+
fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) {
277+
match expr {
278+
// This is the key part: identify constants with type annotations.
279+
Expr::TypedString { value, data_type } => {
280+
if self
281+
.unsupported_types
282+
.contains(data_type.to_string().to_lowercase().as_str())
283+
{
284+
*expr =
285+
Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
286+
}
287+
}
288+
Expr::Cast {
289+
data_type,
290+
expr: value,
291+
..
292+
} => {
293+
if self
294+
.unsupported_types
295+
.contains(data_type.to_string().to_lowercase().as_str())
296+
{
297+
*expr = *value.clone();
298+
}
299+
}
300+
// Handle binary operations by recursively rewriting both sides.
301+
Expr::BinaryOp { left, right, .. } => {
302+
self.rewrite_expr_unsupported_types(left);
303+
self.rewrite_expr_unsupported_types(right);
304+
}
305+
// Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
306+
_ => {}
307+
}
308+
}
309+
}
310+
311+
impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
312+
fn rewrite(&self, mut s: Statement) -> Statement {
313+
// Traverse the AST to find the WHERE clause and rewrite it.
314+
if let Statement::Query(query) = &mut s {
315+
if let SetExpr::Select(select) = query.body.as_mut() {
316+
if let Some(expr) = &mut select.selection {
317+
self.rewrite_expr_unsupported_types(expr);
318+
}
319+
}
320+
}
321+
322+
s
323+
}
324+
}
325+
138326
#[cfg(test)]
139327
mod tests {
140328
use super::*;
@@ -180,4 +368,61 @@ mod tests {
180368
"SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
181369
);
182370
}
371+
372+
#[test]
373+
fn test_qualifier_prepend() {
374+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
375+
vec![Arc::new(ResolveUnqualifiedIdentifer)];
376+
377+
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname";
378+
let statement = parse(sql).expect("Failed to parse").remove(0);
379+
380+
let statement = rewrite(statement, &rules);
381+
assert_eq!(
382+
statement.to_string(),
383+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
384+
);
385+
386+
let sql = "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname";
387+
let statement = parse(sql).expect("Failed to parse").remove(0);
388+
389+
let statement = rewrite(statement, &rules);
390+
assert_eq!(
391+
statement.to_string(),
392+
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
393+
);
394+
395+
let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
396+
let statement = parse(sql).expect("Failed to parse").remove(0);
397+
398+
let statement = rewrite(statement, &rules);
399+
assert_eq!(
400+
statement.to_string(),
401+
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
402+
);
403+
}
404+
405+
#[test]
406+
fn test_remove_unsupported_types() {
407+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
408+
vec![Arc::new(RemoveUnsupportedTypes::new())];
409+
410+
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname";
411+
let statement = parse(sql).expect("Failed to parse").remove(0);
412+
413+
let statement = rewrite(statement, &rules);
414+
assert_eq!(
415+
statement.to_string(),
416+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
417+
);
418+
419+
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
420+
let statement = parse(sql).expect("Failed to parse").remove(0);
421+
422+
let statement = rewrite(statement, &rules);
423+
assert_eq!(
424+
statement.to_string(),
425+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
426+
);
427+
}
183428
}

0 commit comments

Comments
 (0)