@@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema};
55use datafusion:: common:: { ParamValues , ToDFSchema } ;
66use datafusion:: logical_expr:: LogicalPlan ;
77use datafusion:: prelude:: SessionContext ;
8- use datafusion:: sql:: sqlparser:: ast:: Statement ;
8+ use datafusion:: sql:: sqlparser:: ast:: { Set , Statement } ;
99use log:: { info, warn} ;
1010use pgwire:: api:: results:: { DataRowEncoder , FieldFormat , FieldInfo , QueryResponse , Response , Tag } ;
1111use pgwire:: api:: ClientInfo ;
@@ -29,10 +29,7 @@ impl QueryHook for SetShowHook {
2929 ) -> Option < PgWireResult < Response > > {
3030 match statement {
3131 Statement :: Set { .. } => {
32- let query = statement. to_string ( ) ;
33- let query_lower = query. to_lowercase ( ) ;
34-
35- try_respond_set_statements ( client, & query_lower, session_context) . await
32+ try_respond_set_statements ( client, & statement, session_context) . await
3633 }
3734 Statement :: ShowVariable { .. } | Statement :: ShowStatus { .. } => {
3835 let query = statement. to_string ( ) ;
@@ -93,10 +90,7 @@ impl QueryHook for SetShowHook {
9390 ) -> Option < PgWireResult < Response > > {
9491 match statement {
9592 Statement :: Set { .. } => {
96- let query = statement. to_string ( ) ;
97- let query_lower = query. to_lowercase ( ) ;
98-
99- try_respond_set_statements ( client, & query_lower, session_context) . await
93+ try_respond_set_statements ( client, & statement, session_context) . await
10094 }
10195 Statement :: ShowVariable { .. } | Statement :: ShowStatus { .. } => {
10296 let query = statement. to_string ( ) ;
@@ -130,84 +124,75 @@ fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
130124
131125async fn try_respond_set_statements < C > (
132126 client : & mut C ,
133- query_lower : & str ,
127+ statement : & Statement ,
134128 session_context : & SessionContext ,
135129) -> Option < PgWireResult < Response > >
136130where
137131 C : ClientInfo + Send + Sync + ?Sized ,
138132{
139- if query_lower. starts_with ( "set" ) {
140- let result = if query_lower. starts_with ( "set time zone" ) {
141- let parts: Vec < & str > = query_lower. split_whitespace ( ) . collect ( ) ;
142- if parts. len ( ) >= 4 {
143- let tz = parts[ 3 ] . trim_matches ( '"' ) ;
144- client:: set_timezone ( client, Some ( tz) ) ;
145- Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) )
146- } else {
147- Err ( PgWireError :: UserError ( Box :: new (
148- pgwire:: error:: ErrorInfo :: new (
149- "ERROR" . to_string ( ) ,
150- "42601" . to_string ( ) ,
151- "Invalid SET TIME ZONE syntax" . to_string ( ) ,
152- ) ,
153- ) ) )
154- }
155- } else if query_lower. starts_with ( "set statement_timeout" ) {
156- let parts: Vec < & str > = query_lower. split_whitespace ( ) . collect ( ) ;
157- if parts. len ( ) >= 3 {
158- let timeout_str = parts[ 2 ] . trim_matches ( '"' ) . trim_matches ( '\'' ) ;
133+ let Statement :: Set ( set_statement) = statement else {
134+ return None ;
135+ } ;
159136
160- let timeout = if timeout_str == "0" || timeout_str. is_empty ( ) {
161- None
137+ match & set_statement {
138+ Set :: SingleAssignment {
139+ scope : None ,
140+ hivevar : false ,
141+ variable,
142+ values,
143+ } if & variable. to_string ( ) == "statement_timeout" => {
144+ let value = values[ 0 ] . to_string ( ) ;
145+ let timeout_str = value. trim_matches ( '"' ) . trim_matches ( '\'' ) ;
146+
147+ let timeout = if timeout_str == "0" || timeout_str. is_empty ( ) {
148+ None
149+ } else {
150+ // Parse timeout value (supports ms, s, min formats)
151+ let timeout_ms = if timeout_str. ends_with ( "ms" ) {
152+ timeout_str. trim_end_matches ( "ms" ) . parse :: < u64 > ( )
153+ } else if timeout_str. ends_with ( "s" ) {
154+ timeout_str
155+ . trim_end_matches ( "s" )
156+ . parse :: < u64 > ( )
157+ . map ( |s| s * 1000 )
158+ } else if timeout_str. ends_with ( "min" ) {
159+ timeout_str
160+ . trim_end_matches ( "min" )
161+ . parse :: < u64 > ( )
162+ . map ( |m| m * 60 * 1000 )
162163 } else {
163- // Parse timeout value (supports ms, s, min formats)
164- let timeout_ms = if timeout_str. ends_with ( "ms" ) {
165- timeout_str. trim_end_matches ( "ms" ) . parse :: < u64 > ( )
166- } else if timeout_str. ends_with ( "s" ) {
167- timeout_str
168- . trim_end_matches ( "s" )
169- . parse :: < u64 > ( )
170- . map ( |s| s * 1000 )
171- } else if timeout_str. ends_with ( "min" ) {
172- timeout_str
173- . trim_end_matches ( "min" )
174- . parse :: < u64 > ( )
175- . map ( |m| m * 60 * 1000 )
176- } else {
177- // Default to milliseconds
178- timeout_str. parse :: < u64 > ( )
179- } ;
180-
181- match timeout_ms {
182- Ok ( ms) if ms > 0 => Some ( std:: time:: Duration :: from_millis ( ms) ) ,
183- _ => None ,
184- }
164+ // Default to milliseconds
165+ timeout_str. parse :: < u64 > ( )
185166 } ;
186167
187- client:: set_statement_timeout ( client, timeout) ;
188- Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) )
189- } else {
190- Err ( PgWireError :: UserError ( Box :: new (
191- pgwire:: error:: ErrorInfo :: new (
192- "ERROR" . to_string ( ) ,
193- "42601" . to_string ( ) ,
194- "Invalid SET statement_timeout syntax" . to_string ( ) ,
195- ) ,
196- ) ) )
197- }
198- } else {
168+ match timeout_ms {
169+ Ok ( ms) if ms > 0 => Some ( std:: time:: Duration :: from_millis ( ms) ) ,
170+ _ => None ,
171+ }
172+ } ;
173+
174+ client:: set_statement_timeout ( client, timeout) ;
175+ Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
176+ }
177+ Set :: SetTimeZone {
178+ local : false ,
179+ value,
180+ } => {
181+ let tz = value. to_string ( ) ;
182+ let tz = tz. trim_matches ( '"' ) . trim_matches ( '\'' ) ;
183+ client:: set_timezone ( client, Some ( & tz) ) ;
184+ Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
185+ }
186+ _ => {
199187 // pass SET query to datafusion
200- if let Err ( e) = session_context. sql ( query_lower) . await {
201- warn ! ( "SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored" ) ;
188+ let query = statement. to_string ( ) ;
189+ if let Err ( e) = session_context. sql ( & query) . await {
190+ warn ! ( "SET statement {query} is not supported by datafusion, error {e}, statement ignored" ) ;
202191 }
203192
204193 // Always return SET success
205- Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) )
206- } ;
207-
208- Some ( result)
209- } else {
210- None
194+ Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
195+ }
211196 }
212197}
213198
@@ -266,6 +251,8 @@ where
266251mod tests {
267252 use std:: time:: Duration ;
268253
254+ use datafusion:: sql:: sqlparser:: { dialect:: PostgreSqlDialect , parser:: Parser } ;
255+
269256 use super :: * ;
270257 use crate :: testing:: MockClient ;
271258
@@ -275,12 +262,13 @@ mod tests {
275262 let mut client = MockClient :: new ( ) ;
276263
277264 // Test setting timeout to 5000ms
278- let set_response = try_respond_set_statements (
279- & mut client,
280- "set statement_timeout '5000ms'" ,
281- & session_context,
282- )
283- . await ;
265+ let statement = Parser :: new ( & PostgreSqlDialect { } )
266+ . try_with_sql ( "set statement_timeout to '5000ms'" )
267+ . unwrap ( )
268+ . parse_statement ( )
269+ . unwrap ( ) ;
270+ let set_response =
271+ try_respond_set_statements ( & mut client, & statement, & session_context) . await ;
284272
285273 assert ! ( set_response. is_some( ) ) ;
286274 assert ! ( set_response. unwrap( ) . is_ok( ) ) ;
@@ -303,19 +291,22 @@ mod tests {
303291 let mut client = MockClient :: new ( ) ;
304292
305293 // Set timeout first
306- let resp = try_respond_set_statements (
307- & mut client ,
308- "set statement_timeout '1000ms'" ,
309- & session_context ,
310- )
311- . await ;
294+ let statement = Parser :: new ( & PostgreSqlDialect { } )
295+ . try_with_sql ( "set statement_timeout to '1000ms'" )
296+ . unwrap ( )
297+ . parse_statement ( )
298+ . unwrap ( ) ;
299+ let resp = try_respond_set_statements ( & mut client , & statement , & session_context ) . await ;
312300 assert ! ( resp. is_some( ) ) ;
313301 assert ! ( resp. unwrap( ) . is_ok( ) ) ;
314302
315303 // Disable timeout with 0
316- let resp =
317- try_respond_set_statements ( & mut client, "set statement_timeout '0'" , & session_context)
318- . await ;
304+ let statement = Parser :: new ( & PostgreSqlDialect { } )
305+ . try_with_sql ( "set statement_timeout to '0'" )
306+ . unwrap ( )
307+ . parse_statement ( )
308+ . unwrap ( ) ;
309+ let resp = try_respond_set_statements ( & mut client, & statement, & session_context) . await ;
319310 assert ! ( resp. is_some( ) ) ;
320311 assert ! ( resp. unwrap( ) . is_ok( ) ) ;
321312
0 commit comments