Skip to content

Commit 248ed92

Browse files
committed
feat: add rewrite to query handlers
1 parent 0cd06ea commit 248ed92

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
5+
use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule};
56
use async_trait::async_trait;
67
use datafusion::arrow::datatypes::DataType;
78
use datafusion::logical_expr::LogicalPlan;
89
use datafusion::prelude::*;
10+
use datafusion::sql::parser::Statement;
911
use pgwire::api::auth::noop::NoopStartupHandler;
1012
use pgwire::api::auth::StartupHandler;
1113
use pgwire::api::portal::{Format, Portal};
@@ -63,21 +65,26 @@ pub struct DfSessionService {
6365
parser: Arc<Parser>,
6466
timezone: Arc<Mutex<String>>,
6567
auth_manager: Arc<AuthManager>,
68+
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
6669
}
6770

6871
impl DfSessionService {
6972
pub fn new(
7073
session_context: Arc<SessionContext>,
7174
auth_manager: Arc<AuthManager>,
7275
) -> DfSessionService {
76+
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
77+
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
7378
let parser = Arc::new(Parser {
7479
session_context: session_context.clone(),
80+
sql_rewrite_rules: sql_rewrite_rules.clone(),
7581
});
7682
DfSessionService {
7783
session_context,
7884
parser,
7985
timezone: Arc::new(Mutex::new("UTC".to_string())),
8086
auth_manager,
87+
sql_rewrite_rules,
8188
}
8289
}
8390

@@ -308,8 +315,17 @@ impl SimpleQueryHandler for DfSessionService {
308315
where
309316
C: ClientInfo + Unpin + Send + Sync,
310317
{
311-
let query_lower = query.to_lowercase().trim().to_string();
312318
log::debug!("Received query: {query}"); // Log the query for debugging
319+
let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
320+
321+
// TODO: deal with multiple statements
322+
let mut statement = statements.remove(0);
323+
324+
// Attempt to rewrite
325+
statement = rewrite(statement, &self.sql_rewrite_rules);
326+
327+
// TODO: improve statement check by using statement directly
328+
let query_lower = statement.to_string().to_lowercase().trim().to_string();
313329

314330
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
315331
if !query_lower.starts_with("set")
@@ -526,6 +542,7 @@ impl ExtendedQueryHandler for DfSessionService {
526542

527543
pub struct Parser {
528544
session_context: Arc<SessionContext>,
545+
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
529546
}
530547

531548
#[async_trait]
@@ -538,14 +555,23 @@ impl QueryParser for Parser {
538555
sql: &str,
539556
_types: &[Type],
540557
) -> PgWireResult<Self::Statement> {
541-
log::debug!("Received parse extended query: {sql}"); // Log for debugging
558+
log::debug!("Received parse extended query: {sql}"); // Log for
559+
// debugging
560+
let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
561+
let mut statement = statements.remove(0);
562+
563+
// Attempt to rewrite
564+
statement = rewrite(statement, &self.sql_rewrite_rules);
565+
566+
let query = statement.to_string();
567+
542568
let context = &self.session_context;
543569
let state = context.state();
544570
let logical_plan = state
545-
.create_logical_plan(sql)
571+
.statement_to_plan(Statement::Statement(Box::new(statement)))
546572
.await
547573
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
548-
Ok((sql.to_string(), logical_plan))
574+
Ok((query, logical_plan))
549575
}
550576
}
551577

datafusion-postgres/src/sql.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use datafusion::sql::sqlparser::ast::Expr;
24
use datafusion::sql::sqlparser::ast::Ident;
35
use datafusion::sql::sqlparser::ast::Select;
@@ -15,15 +17,15 @@ pub fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
1517
Parser::parse_sql(&dialect, sql)
1618
}
1719

18-
pub fn rewrite(mut s: Statement, rules: &[Box<dyn SqlStatementRewriteRule>]) -> Statement {
20+
pub fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
1921
for rule in rules {
2022
s = rule.rewrite(s);
2123
}
2224

2325
s
2426
}
2527

26-
pub trait SqlStatementRewriteRule {
28+
pub trait SqlStatementRewriteRule: Send + Sync {
2729
fn rewrite(&self, s: Statement) -> Statement;
2830
}
2931

@@ -35,7 +37,8 @@ pub trait SqlStatementRewriteRule {
3537
///
3638
/// This rule will add alias to column, when there is a wildcard found in
3739
/// projection.
38-
struct AliasDuplicatedProjectionRewrite;
40+
#[derive(Debug)]
41+
pub struct AliasDuplicatedProjectionRewrite;
3942

4043
impl AliasDuplicatedProjectionRewrite {
4144
// Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
@@ -103,7 +106,7 @@ impl AliasDuplicatedProjectionRewrite {
103106
};
104107

105108
if let Some(name) = alias_partial {
106-
let alias = format!("__alias_{}", name);
109+
let alias = format!("__alias_{name}");
107110
new_projection.push(SelectItem::ExprWithAlias {
108111
expr,
109112
alias: Ident::new(alias),
@@ -138,8 +141,8 @@ mod tests {
138141

139142
#[test]
140143
fn test_alias_rewrite() {
141-
let rules: Vec<Box<dyn SqlStatementRewriteRule>> =
142-
vec![Box::new(AliasDuplicatedProjectionRewrite)];
144+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
145+
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
143146

144147
let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n";
145148
let statement = parse(sql).expect("Failed to parse").remove(0);

0 commit comments

Comments
 (0)