@@ -2,10 +2,12 @@ use std::collections::HashMap;
22use std:: sync:: Arc ;
33
44use crate :: auth:: { AuthManager , Permission , ResourceType } ;
5+ use crate :: sql:: { parse, rewrite, AliasDuplicatedProjectionRewrite , SqlStatementRewriteRule } ;
56use async_trait:: async_trait;
67use datafusion:: arrow:: datatypes:: DataType ;
78use datafusion:: logical_expr:: LogicalPlan ;
89use datafusion:: prelude:: * ;
10+ use datafusion:: sql:: parser:: Statement ;
911use pgwire:: api:: auth:: noop:: NoopStartupHandler ;
1012use pgwire:: api:: auth:: StartupHandler ;
1113use pgwire:: api:: portal:: { Format , Portal } ;
@@ -63,21 +65,26 @@ pub struct DfSessionService {
6365 parser : Arc < Parser > ,
6466 timezone : Arc < Mutex < String > > ,
6567 auth_manager : Arc < AuthManager > ,
68+ sql_rewrite_rules : Vec < Arc < dyn SqlStatementRewriteRule > > ,
6669}
6770
6871impl DfSessionService {
6972 pub fn new (
7073 session_context : Arc < SessionContext > ,
7174 auth_manager : Arc < AuthManager > ,
7275 ) -> DfSessionService {
76+ let sql_rewrite_rules: Vec < Arc < dyn SqlStatementRewriteRule > > =
77+ vec ! [ Arc :: new( AliasDuplicatedProjectionRewrite ) ] ;
7378 let parser = Arc :: new ( Parser {
7479 session_context : session_context. clone ( ) ,
80+ sql_rewrite_rules : sql_rewrite_rules. clone ( ) ,
7581 } ) ;
7682 DfSessionService {
7783 session_context,
7884 parser,
7985 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
8086 auth_manager,
87+ sql_rewrite_rules,
8188 }
8289 }
8390
@@ -308,8 +315,17 @@ impl SimpleQueryHandler for DfSessionService {
308315 where
309316 C : ClientInfo + Unpin + Send + Sync ,
310317 {
311- let query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ;
312318 log:: debug!( "Received query: {query}" ) ; // Log the query for debugging
319+ let mut statements = parse ( query) . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
320+
321+ // TODO: deal with multiple statements
322+ let mut statement = statements. remove ( 0 ) ;
323+
324+ // Attempt to rewrite
325+ statement = rewrite ( statement, & self . sql_rewrite_rules ) ;
326+
327+ // TODO: improve statement check by using statement directly
328+ let query_lower = statement. to_string ( ) . to_lowercase ( ) . trim ( ) . to_string ( ) ;
313329
314330 // Check permissions for the query (skip for SET, transaction, and SHOW statements)
315331 if !query_lower. starts_with ( "set" )
@@ -526,6 +542,7 @@ impl ExtendedQueryHandler for DfSessionService {
526542
527543pub struct Parser {
528544 session_context : Arc < SessionContext > ,
545+ sql_rewrite_rules : Vec < Arc < dyn SqlStatementRewriteRule > > ,
529546}
530547
531548#[ async_trait]
@@ -538,14 +555,23 @@ impl QueryParser for Parser {
538555 sql : & str ,
539556 _types : & [ Type ] ,
540557 ) -> PgWireResult < Self :: Statement > {
541- log:: debug!( "Received parse extended query: {sql}" ) ; // Log for debugging
558+ log:: debug!( "Received parse extended query: {sql}" ) ; // Log for
559+ // debugging
560+ let mut statements = parse ( sql) . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
561+ let mut statement = statements. remove ( 0 ) ;
562+
563+ // Attempt to rewrite
564+ statement = rewrite ( statement, & self . sql_rewrite_rules ) ;
565+
566+ let query = statement. to_string ( ) ;
567+
542568 let context = & self . session_context ;
543569 let state = context. state ( ) ;
544570 let logical_plan = state
545- . create_logical_plan ( sql )
571+ . statement_to_plan ( Statement :: Statement ( Box :: new ( statement ) ) )
546572 . await
547573 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
548- Ok ( ( sql . to_string ( ) , logical_plan) )
574+ Ok ( ( query , logical_plan) )
549575 }
550576}
551577
0 commit comments