@@ -3,9 +3,10 @@ use std::sync::Arc;
33use async_trait:: async_trait;
44use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
55use datafusion:: common:: { ParamValues , ToDFSchema } ;
6+ use datafusion:: error:: DataFusionError ;
67use datafusion:: logical_expr:: LogicalPlan ;
78use datafusion:: prelude:: SessionContext ;
8- use datafusion:: sql:: sqlparser:: ast:: { Set , Statement } ;
9+ use datafusion:: sql:: sqlparser:: ast:: { Expr , Set , Statement } ;
910use log:: { info, warn} ;
1011use pgwire:: api:: results:: { DataRowEncoder , FieldFormat , FieldInfo , QueryResponse , Response , Tag } ;
1112use pgwire:: api:: ClientInfo ;
@@ -134,39 +135,53 @@ where
134135 hivevar : false ,
135136 variable,
136137 values,
137- } if & variable. to_string ( ) == "statement_timeout" => {
138- let value = values[ 0 ] . to_string ( ) ;
139- let timeout_str = value. trim_matches ( '"' ) . trim_matches ( '\'' ) ;
140-
141- let timeout = if timeout_str == "0" || timeout_str. is_empty ( ) {
142- None
143- } else {
144- // Parse timeout value (supports ms, s, min formats)
145- let timeout_ms = if timeout_str. ends_with ( "ms" ) {
146- timeout_str. trim_end_matches ( "ms" ) . parse :: < u64 > ( )
147- } else if timeout_str. ends_with ( "s" ) {
148- timeout_str
149- . trim_end_matches ( "s" )
150- . parse :: < u64 > ( )
151- . map ( |s| s * 1000 )
152- } else if timeout_str. ends_with ( "min" ) {
153- timeout_str
154- . trim_end_matches ( "min" )
155- . parse :: < u64 > ( )
156- . map ( |m| m * 60 * 1000 )
138+ } => {
139+ let var = variable. to_string ( ) . to_lowercase ( ) ;
140+ if var == "statement_timeout" {
141+ let value = values[ 0 ] . to_string ( ) ;
142+ let timeout_str = value. trim_matches ( '"' ) . trim_matches ( '\'' ) ;
143+
144+ let timeout = if timeout_str == "0" || timeout_str. is_empty ( ) {
145+ None
157146 } else {
158- // Default to milliseconds
159- timeout_str. parse :: < u64 > ( )
147+ // Parse timeout value (supports ms, s, min formats)
148+ let timeout_ms = if timeout_str. ends_with ( "ms" ) {
149+ timeout_str. trim_end_matches ( "ms" ) . parse :: < u64 > ( )
150+ } else if timeout_str. ends_with ( "s" ) {
151+ timeout_str
152+ . trim_end_matches ( "s" )
153+ . parse :: < u64 > ( )
154+ . map ( |s| s * 1000 )
155+ } else if timeout_str. ends_with ( "min" ) {
156+ timeout_str
157+ . trim_end_matches ( "min" )
158+ . parse :: < u64 > ( )
159+ . map ( |m| m * 60 * 1000 )
160+ } else {
161+ // Default to milliseconds
162+ timeout_str. parse :: < u64 > ( )
163+ } ;
164+
165+ match timeout_ms {
166+ Ok ( ms) if ms > 0 => Some ( std:: time:: Duration :: from_millis ( ms) ) ,
167+ _ => None ,
168+ }
160169 } ;
161170
162- match timeout_ms {
163- Ok ( ms) if ms > 0 => Some ( std:: time:: Duration :: from_millis ( ms) ) ,
164- _ => None ,
171+ client:: set_statement_timeout ( client, timeout) ;
172+ return Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) ) ;
173+ } else if matches ! ( var. as_str( ) , "datestyle" | "bytea_output" | "intervalstyle" ) {
174+ if values. len ( ) > 0 {
175+ // postgres configuration variables
176+ let value = values[ 0 ] . clone ( ) ;
177+ if let Expr :: Value ( value) = value {
178+ client
179+ . metadata_mut ( )
180+ . insert ( var, value. into_string ( ) . unwrap_or_else ( || "" . to_string ( ) ) ) ;
181+ return Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) ) ;
182+ }
165183 }
166- } ;
167-
168- client:: set_statement_timeout ( client, timeout) ;
169- Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
184+ }
170185 }
171186 Set :: SetTimeZone {
172187 local : false ,
@@ -175,19 +190,39 @@ where
175190 let tz = value. to_string ( ) ;
176191 let tz = tz. trim_matches ( '"' ) . trim_matches ( '\'' ) ;
177192 client:: set_timezone ( client, Some ( tz) ) ;
178- Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
193+ return Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) ) ;
179194 }
180- _ => {
181- // pass SET query to datafusion
182- let query = statement. to_string ( ) ;
183- if let Err ( e) = session_context. sql ( & query) . await {
184- warn ! ( "SET statement {query} is not supported by datafusion, error {e}, statement ignored" ) ;
185- }
195+ _ => { }
196+ }
186197
187- // Always return SET success
188- Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
189- }
198+ // fallback to datafusion and ignore all errors
199+ if let Err ( e) = execute_set_statement ( session_context, statement. clone ( ) ) . await {
200+ warn ! (
201+ "SET statement {} is not supported by datafusion, error {e}, statement ignored" ,
202+ statement. to_string( )
203+ ) ;
190204 }
205+
206+ // Always return SET success
207+ Some ( Ok ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
208+ }
209+
210+ async fn execute_set_statement (
211+ session_context : & SessionContext ,
212+ statement : Statement ,
213+ ) -> Result < ( ) , DataFusionError > {
214+ let state = session_context. state ( ) ;
215+ let logical_plan = state
216+ . statement_to_plan ( datafusion:: sql:: parser:: Statement :: Statement ( Box :: new (
217+ statement,
218+ ) ) )
219+ . await
220+ . and_then ( |logical_plan| state. optimize ( & logical_plan) ) ?;
221+
222+ session_context
223+ . execute_logical_plan ( logical_plan)
224+ . await
225+ . map ( |_| ( ) )
191226}
192227
193228async fn try_respond_show_statements < C > (
@@ -204,10 +239,11 @@ where
204239
205240 let variables = variable
206241 . iter ( )
207- . map ( |v| & v. value as & str )
242+ . map ( |v| v. value . to_lowercase ( ) )
208243 . collect :: < Vec < _ > > ( ) ;
244+ let variables_ref = variables. iter ( ) . map ( |s| s. as_str ( ) ) . collect :: < Vec < _ > > ( ) ;
209245
210- match & variables as & [ & str ] {
246+ match variables_ref . as_slice ( ) {
211247 [ "time" , "zone" ] => {
212248 let timezone = client:: get_timezone ( client) . unwrap_or ( "UTC" ) ;
213249 Some ( mock_show_response ( "TimeZone" , timezone) . map ( Response :: Query ) )
@@ -238,6 +274,14 @@ where
238274 [ "transaction" , "isolation" , "level" ] => {
239275 Some ( mock_show_response ( "transaction_isolation" , "read_committed" ) . map ( Response :: Query ) )
240276 }
277+ [ "bytea_output" ] | [ "datestyle" ] | [ "intervalstyle" ] => {
278+ let val = client
279+ . metadata ( )
280+ . get ( & variables[ 0 ] )
281+ . map ( |v| v. as_str ( ) )
282+ . unwrap_or ( "" ) ;
283+ Some ( mock_show_response ( & variables[ 0 ] , val) . map ( Response :: Query ) )
284+ }
241285 _ => {
242286 info ! ( "Unsupported show statement: {}" , statement) ;
243287 Some ( mock_show_response ( "unsupported_show_statement" , "" ) . map ( Response :: Query ) )
@@ -288,6 +332,74 @@ mod tests {
288332 assert ! ( show_response. unwrap( ) . is_ok( ) ) ;
289333 }
290334
335+ #[ tokio:: test]
336+ async fn test_bytea_output_set_and_show ( ) {
337+ let session_context = SessionContext :: new ( ) ;
338+ let mut client = MockClient :: new ( ) ;
339+
340+ // Test setting timeout to 5000ms
341+ let statement = Parser :: new ( & PostgreSqlDialect { } )
342+ . try_with_sql ( "set bytea_output = 'hex'" )
343+ . unwrap ( )
344+ . parse_statement ( )
345+ . unwrap ( ) ;
346+ let set_response =
347+ try_respond_set_statements ( & mut client, & statement, & session_context) . await ;
348+
349+ assert ! ( set_response. is_some( ) ) ;
350+ assert ! ( set_response. unwrap( ) . is_ok( ) ) ;
351+
352+ // Verify the timeout was set in client metadata
353+ let bytea_output = client. metadata ( ) . get ( "bytea_output" ) . unwrap ( ) ;
354+ assert_eq ! ( bytea_output, "hex" ) ;
355+
356+ // Test SHOW statement_timeout
357+ let statement = Parser :: new ( & PostgreSqlDialect { } )
358+ . try_with_sql ( "show bytea_output" )
359+ . unwrap ( )
360+ . parse_statement ( )
361+ . unwrap ( ) ;
362+ let show_response =
363+ try_respond_show_statements ( & client, & statement, & session_context) . await ;
364+
365+ assert ! ( show_response. is_some( ) ) ;
366+ assert ! ( show_response. unwrap( ) . is_ok( ) ) ;
367+ }
368+
369+ #[ tokio:: test]
370+ async fn test_date_style_set_and_show ( ) {
371+ let session_context = SessionContext :: new ( ) ;
372+ let mut client = MockClient :: new ( ) ;
373+
374+ // Test setting timeout to 5000ms
375+ let statement = Parser :: new ( & PostgreSqlDialect { } )
376+ . try_with_sql ( "set dateStyle = 'ISO, DMY'" )
377+ . unwrap ( )
378+ . parse_statement ( )
379+ . unwrap ( ) ;
380+ let set_response =
381+ try_respond_set_statements ( & mut client, & statement, & session_context) . await ;
382+
383+ assert ! ( set_response. is_some( ) ) ;
384+ assert ! ( set_response. unwrap( ) . is_ok( ) ) ;
385+
386+ // Verify the timeout was set in client metadata
387+ let bytea_output = client. metadata ( ) . get ( "datestyle" ) . unwrap ( ) ;
388+ assert_eq ! ( bytea_output, "ISO, DMY" ) ;
389+
390+ // Test SHOW statement_timeout
391+ let statement = Parser :: new ( & PostgreSqlDialect { } )
392+ . try_with_sql ( "show dateStyle" )
393+ . unwrap ( )
394+ . parse_statement ( )
395+ . unwrap ( ) ;
396+ let show_response =
397+ try_respond_show_statements ( & client, & statement, & session_context) . await ;
398+
399+ assert ! ( show_response. is_some( ) ) ;
400+ assert ! ( show_response. unwrap( ) . is_ok( ) ) ;
401+ }
402+
291403 #[ tokio:: test]
292404 async fn test_statement_timeout_disable ( ) {
293405 let session_context = SessionContext :: new ( ) ;
0 commit comments