1+ // src/handlers.rs
12use std:: collections:: HashMap ;
23use std:: sync:: Arc ;
34
@@ -15,8 +16,14 @@ use pgwire::api::stmt::StoredStatement;
1516use pgwire:: api:: { ClientInfo , NoopErrorHandler , PgWireServerHandlers , Type } ;
1617use pgwire:: error:: { PgWireError , PgWireResult } ;
1718
19+ // --- ADD THESE IMPORTS FOR MULTI-STATEMENT PARSING ---
20+ use sqlparser:: dialect:: GenericDialect ;
21+ use sqlparser:: parser:: Parser as SqlParser ;
22+ // ------------------------------------------------------
23+
1824use crate :: datatypes:: { self , into_pg_type} ;
1925
26+ /// A factory that creates our handlers for the PGWire server.
2027pub struct HandlerFactory ( pub Arc < DfSessionService > ) ;
2128
2229impl NoopStartupHandler for DfSessionService { }
@@ -49,9 +56,10 @@ impl PgWireServerHandlers for HandlerFactory {
4956 }
5057}
5158
59+ /// Our primary session service, storing a DataFusion `SessionContext`.
5260pub struct DfSessionService {
53- session_context : Arc < SessionContext > ,
54- parser : Arc < Parser > ,
61+ pub session_context : Arc < SessionContext > ,
62+ pub parser : Arc < Parser > ,
5563}
5664
5765impl DfSessionService {
@@ -67,27 +75,7 @@ impl DfSessionService {
6775 }
6876}
6977
70- #[ async_trait]
71- impl SimpleQueryHandler for DfSessionService {
72- async fn do_query < ' a , C > (
73- & self ,
74- _client : & mut C ,
75- query : & ' a str ,
76- ) -> PgWireResult < Vec < Response < ' a > > >
77- where
78- C : ClientInfo + Unpin + Send + Sync ,
79- {
80- let ctx = & self . session_context ;
81- let df = ctx
82- . sql ( query)
83- . await
84- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
85-
86- let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
87- Ok ( vec ! [ Response :: Query ( resp) ] )
88- }
89- }
90-
78+ /// A simple parser that builds a logical plan from SQL text, using DataFusion.
9179pub struct Parser {
9280 session_context : Arc < SessionContext > ,
9381}
@@ -104,18 +92,68 @@ impl QueryParser for Parser {
10492 . create_logical_plan ( sql)
10593 . await
10694 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
107- let optimised = state
95+ let optimized = state
10896 . optimize ( & logical_plan)
10997 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
11098
111- Ok ( optimised )
99+ Ok ( optimized )
112100 }
113101}
114102
103+ // ----------------------------------------------------------------
104+ // SimpleQueryHandler Implementation (multi-statement support)
105+ // ----------------------------------------------------------------
106+ #[ async_trait]
107+ impl SimpleQueryHandler for DfSessionService {
108+ async fn do_query < ' a , C > (
109+ & self ,
110+ _client : & mut C ,
111+ query : & ' a str ,
112+ ) -> PgWireResult < Vec < Response < ' a > > >
113+ where
114+ C : ClientInfo + Unpin + Send + Sync ,
115+ {
116+ // 1) Parse the incoming query string into multiple statements using sqlparser.
117+ let dialect = GenericDialect { } ;
118+ let stmts = match SqlParser :: parse_sql ( & dialect, query) {
119+ Ok ( s) => s,
120+ Err ( e) => {
121+ return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
122+ }
123+ } ;
124+
125+ // 2) For each parsed statement, execute with DataFusion and collect results.
126+ let mut responses = Vec :: with_capacity ( stmts. len ( ) ) ;
127+ for statement in stmts {
128+ // Convert the AST statement back to SQL text
129+ // (some statements might be empty if there's a trailing semicolon)
130+ let stmt_string = statement. to_string ( ) . trim ( ) . to_owned ( ) ;
131+ if stmt_string. is_empty ( ) {
132+ continue ;
133+ }
134+
135+ // Execute the statement in DataFusion
136+ let df = self
137+ . session_context
138+ . sql ( & stmt_string)
139+ . await
140+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
141+
142+ // 3) Encode the DataFrame into a QueryResponse for the client
143+ let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
144+ responses. push ( Response :: Query ( resp) ) ;
145+ }
146+
147+ Ok ( responses)
148+ }
149+ }
150+
151+ // ----------------------------------------------------------------
152+ // ExtendedQueryHandler Implementation (same as original)
153+ // ----------------------------------------------------------------
115154#[ async_trait]
116155impl ExtendedQueryHandler for DfSessionService {
117156 type Statement = LogicalPlan ;
118-
119157 type QueryParser = Parser ;
120158
121159 fn query_parser ( & self ) -> Arc < Self :: QueryParser > {
@@ -201,11 +239,11 @@ impl ExtendedQueryHandler for DfSessionService {
201239 }
202240}
203241
242+ /// Helper to convert DataFusion’s parameter map into an ordered list.
204243fn ordered_param_types ( types : & HashMap < String , Option < DataType > > ) -> Vec < Option < & DataType > > {
205- // Datafusion stores the parameters as a map. In our case, the keys will be
206- // `$1`, `$2` etc. The values will be the parameter types.
207-
208- let mut types = types. iter ( ) . collect :: < Vec < _ > > ( ) ;
209- types. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
210- types. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
244+ // DataFusion stores parameters as a map keyed by "$1", "$2", etc.
245+ // We sort them in ascending order by key to match the expected param order.
246+ let mut types_vec = types. iter ( ) . collect :: < Vec < _ > > ( ) ;
247+ types_vec. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
248+ types_vec. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
211249}
0 commit comments