@@ -124,7 +124,7 @@ impl DfSessionService {
124124 * sc_guard = new_context;
125125 Ok ( ( ) )
126126 }
127- "client_encoding"
127+ | "client_encoding"
128128 | "search_path"
129129 | "application_name"
130130 | "datestyle"
@@ -154,7 +154,6 @@ impl DfSessionService {
154154
155155 let sc_guard = self . session_context . read ( ) . await ;
156156 let config = sc_guard. state ( ) . config ( ) . options ( ) . clone ( ) ;
157- drop ( sc_guard) ;
158157
159158 let value = match var_name. as_str ( ) {
160159 "timezone" => config
@@ -232,6 +231,17 @@ impl DfSessionService {
232231 . get ( & var_name)
233232 . cloned ( )
234233 . unwrap_or_else ( || "read committed" . to_string ( ) ) ,
234+
235+ // *** New variables to keep psql happy ***
236+ "server_version" => "14.0" . to_string ( ) ,
237+ "server_version_num" => "140000" . to_string ( ) ,
238+ "server_encoding" => "UTF8" . to_string ( ) ,
239+ "is_superuser" => "off" . to_string ( ) ,
240+ "lc_messages" => "en_US.UTF-8" . to_string ( ) ,
241+ "lc_monetary" => "en_US.UTF-8" . to_string ( ) ,
242+ "lc_numeric" => "en_US.UTF-8" . to_string ( ) ,
243+ "lc_time" => "en_US.UTF-8" . to_string ( ) ,
244+
235245 "all" => {
236246 let mut names = Vec :: new ( ) ;
237247 let mut values = Vec :: new ( ) ;
@@ -240,50 +250,39 @@ impl DfSessionService {
240250 names. push ( "timezone" . to_string ( ) ) ;
241251 values. push ( tz. clone ( ) ) ;
242252 }
253+
243254 let custom_vars = self . custom_session_vars . read ( ) . await ;
244255 for ( name, value) in custom_vars. iter ( ) {
245256 names. push ( name. clone ( ) ) ;
246257 values. push ( value. clone ( ) ) ;
247258 }
248- if !custom_vars. contains_key ( "client_encoding" ) {
249- names. push ( "client_encoding" . to_string ( ) ) ;
250- values. push ( "UTF8" . to_string ( ) ) ;
251- }
252- if !custom_vars. contains_key ( "search_path" ) {
253- names. push ( "search_path" . to_string ( ) ) ;
254- values. push ( "public" . to_string ( ) ) ;
255- }
256- if !custom_vars. contains_key ( "application_name" ) {
257- names. push ( "application_name" . to_string ( ) ) ;
258- values. push ( "" . to_string ( ) ) ;
259- }
260- if !custom_vars. contains_key ( "datestyle" ) {
261- names. push ( "datestyle" . to_string ( ) ) ;
262- values. push ( "ISO, MDY" . to_string ( ) ) ;
263- }
264- if !custom_vars. contains_key ( "client_min_messages" ) {
265- names. push ( "client_min_messages" . to_string ( ) ) ;
266- values. push ( "notice" . to_string ( ) ) ;
267- }
268- if !custom_vars. contains_key ( "extra_float_digits" ) {
269- names. push ( "extra_float_digits" . to_string ( ) ) ;
270- values. push ( "3" . to_string ( ) ) ;
271- }
272- if !custom_vars. contains_key ( "standard_conforming_strings" ) {
273- names. push ( "standard_conforming_strings" . to_string ( ) ) ;
274- values. push ( "on" . to_string ( ) ) ;
275- }
276- if !custom_vars. contains_key ( "check_function_bodies" ) {
277- names. push ( "check_function_bodies" . to_string ( ) ) ;
278- values. push ( "off" . to_string ( ) ) ;
279- }
280- if !custom_vars. contains_key ( "transaction_read_only" ) {
281- names. push ( "transaction_read_only" . to_string ( ) ) ;
282- values. push ( "off" . to_string ( ) ) ;
283- }
284- if !custom_vars. contains_key ( "transaction_isolation" ) {
285- names. push ( "transaction_isolation" . to_string ( ) ) ;
286- values. push ( "read committed" . to_string ( ) ) ;
259+
260+ let defaults = vec ! [
261+ ( "client_encoding" , "UTF8" ) ,
262+ ( "search_path" , "public" ) ,
263+ ( "application_name" , "" ) ,
264+ ( "datestyle" , "ISO, MDY" ) ,
265+ ( "client_min_messages" , "notice" ) ,
266+ ( "extra_float_digits" , "3" ) ,
267+ ( "standard_conforming_strings" , "on" ) ,
268+ ( "check_function_bodies" , "off" ) ,
269+ ( "transaction_read_only" , "off" ) ,
270+ ( "transaction_isolation" , "read committed" ) ,
271+ ( "server_version" , "14.0" ) ,
272+ ( "server_version_num" , "140000" ) ,
273+ ( "server_encoding" , "UTF8" ) ,
274+ ( "is_superuser" , "off" ) ,
275+ ( "lc_messages" , "en_US.UTF-8" ) ,
276+ ( "lc_monetary" , "en_US.UTF-8" ) ,
277+ ( "lc_numeric" , "en_US.UTF-8" ) ,
278+ ( "lc_time" , "en_US.UTF-8" ) ,
279+ ] ;
280+
281+ for ( k, v) in defaults {
282+ if !names. contains ( & k. to_string ( ) ) {
283+ names. push ( k. to_string ( ) ) ;
284+ values. push ( v. to_string ( ) ) ;
285+ }
287286 }
288287
289288 let schema = Arc :: new ( Schema :: new ( vec ! [
@@ -298,13 +297,13 @@ impl DfSessionService {
298297 ] ,
299298 )
300299 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
301- let sc_guard = self . session_context . read ( ) . await ;
300+
302301 let df = sc_guard
303302 . read_batch ( batch)
304303 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
305- drop ( sc_guard) ;
306304 return datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ;
307305 }
306+
308307 _ => {
309308 return Err ( PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
310309 "ERROR" . to_string ( ) ,
@@ -315,13 +314,12 @@ impl DfSessionService {
315314 } ;
316315
317316 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( & var_name, DataType :: Utf8 , false ) ] ) ) ;
318- let batch = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( StringArray :: from( vec![ value] ) ) ] )
317+ let batch = RecordBatch :: try_new ( schema, vec ! [ Arc :: new( StringArray :: from( vec![ value] ) ) ] )
319318 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
320- let sc_guard = self . session_context . read ( ) . await ;
321319 let df = sc_guard
322320 . read_batch ( batch)
323321 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
324- drop ( sc_guard ) ;
322+
325323 datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await
326324 }
327325}
@@ -333,6 +331,7 @@ pub struct Parser {
333331#[ async_trait]
334332impl QueryParser for Parser {
335333 type Statement = LogicalPlan ;
334+
336335 async fn parse_sql ( & self , sql : & str , _types : & [ Type ] ) -> PgWireResult < Self :: Statement > {
337336 let sc_guard = self . session_context . read ( ) . await ;
338337 let state = sc_guard. state ( ) ;
@@ -361,6 +360,7 @@ impl SimpleQueryHandler for DfSessionService {
361360 let stmts = SqlParser :: parse_sql ( & dialect, query)
362361 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
363362 let mut responses = Vec :: with_capacity ( stmts. len ( ) ) ;
363+
364364 for statement in stmts {
365365 let stmt_string = statement. to_string ( ) . trim ( ) . to_owned ( ) ;
366366 if stmt_string. is_empty ( ) {
@@ -387,7 +387,6 @@ impl SimpleQueryHandler for DfSessionService {
387387 . sql ( & stmt_string)
388388 . await
389389 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
390- drop ( sc_guard) ;
391390 let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
392391 responses. push ( Response :: Query ( resp) ) ;
393392 }
@@ -401,9 +400,11 @@ impl SimpleQueryHandler for DfSessionService {
401400impl ExtendedQueryHandler for DfSessionService {
402401 type Statement = LogicalPlan ;
403402 type QueryParser = Parser ;
403+
404404 fn query_parser ( & self ) -> Arc < Self :: QueryParser > {
405405 self . parser . clone ( )
406406 }
407+
407408 async fn do_describe_statement < C > (
408409 & self ,
409410 _client : & mut C ,
@@ -420,6 +421,7 @@ impl ExtendedQueryHandler for DfSessionService {
420421 . get_parameter_types ( )
421422 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
422423 let mut param_types = Vec :: with_capacity ( params. len ( ) ) ;
424+
423425 for param_type in ordered_param_types ( & params) . iter ( ) {
424426 if let Some ( datatype) = param_type {
425427 let pgtype = into_pg_type ( datatype) ?;
@@ -430,6 +432,7 @@ impl ExtendedQueryHandler for DfSessionService {
430432 }
431433 Ok ( DescribeStatementResponse :: new ( param_types, fields) )
432434 }
435+
433436 async fn do_describe_portal < C > (
434437 & self ,
435438 _client : & mut C ,
@@ -444,6 +447,7 @@ impl ExtendedQueryHandler for DfSessionService {
444447 let fields = datatypes:: df_schema_to_pg_fields ( schema. as_ref ( ) , format) ?;
445448 Ok ( DescribePortalResponse :: new ( fields) )
446449 }
450+
447451 async fn do_query < ' a , C > (
448452 & self ,
449453 _client : & mut C ,
@@ -455,6 +459,8 @@ impl ExtendedQueryHandler for DfSessionService {
455459 {
456460 let stmt_string = portal. statement . id . clone ( ) ;
457461 let stmt_upper = stmt_string. to_uppercase ( ) ;
462+
463+ // If the statement is a SET or SHOW, handle it here
458464 if stmt_upper. starts_with ( "SET " ) {
459465 let dialect = GenericDialect { } ;
460466 let stmts = SqlParser :: parse_sql ( & dialect, & stmt_string)
@@ -476,6 +482,8 @@ impl ExtendedQueryHandler for DfSessionService {
476482 return Ok ( Response :: Query ( resp) ) ;
477483 }
478484 }
485+
486+ // Otherwise, treat it as a normal prepared statement
479487 let plan = & portal. statement . statement ;
480488 let param_types = plan
481489 . get_parameter_types ( )
@@ -486,12 +494,13 @@ impl ExtendedQueryHandler for DfSessionService {
486494 . clone ( )
487495 . replace_params_with_values ( & param_values)
488496 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
497+
489498 let sc_guard = self . session_context . read ( ) . await ;
490499 let dataframe = sc_guard
491500 . execute_logical_plan ( plan)
492501 . await
493502 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
494- drop ( sc_guard ) ;
503+
495504 let resp = datatypes:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
496505 Ok ( Response :: Query ( resp) )
497506 }
@@ -500,6 +509,8 @@ impl ExtendedQueryHandler for DfSessionService {
500509fn ordered_param_types (
501510 types : & HashMap < String , Option < DataType > > ,
502511) -> Vec < Option < & DataType > > {
512+ // Datafusion stores the parameters as a map. In our case, the keys will be
513+ // `$1`, `$2` etc. The values will be the parameter types.
503514 let mut types_vec = types. iter ( ) . collect :: < Vec < _ > > ( ) ;
504515 types_vec. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
505516 types_vec. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
0 commit comments