@@ -2,10 +2,10 @@ use std::collections::HashMap;
22use std:: sync:: Arc ;
33
44use async_trait:: async_trait;
5- use datafusion:: arrow:: array:: StringArray ;
5+ use datafusion:: arrow:: array:: { ListBuilder , StringArray , StringBuilder } ;
66use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
77use datafusion:: arrow:: record_batch:: RecordBatch ;
8- use datafusion:: logical_expr:: LogicalPlan ;
8+ use datafusion:: logical_expr:: { create_udf , ColumnarValue , LogicalPlan , Volatility } ;
99use datafusion:: prelude:: * ;
1010use pgwire:: api:: auth:: noop:: NoopStartupHandler ;
1111use pgwire:: api:: copy:: NoopCopyHandler ;
@@ -17,7 +17,7 @@ use pgwire::api::results::{
1717use pgwire:: api:: stmt:: { QueryParser , StoredStatement } ;
1818use pgwire:: api:: { ClientInfo , NoopErrorHandler , PgWireServerHandlers , Type } ;
1919use pgwire:: error:: { PgWireError , PgWireResult } ;
20- use sqlparser:: ast:: { Expr , Ident , ObjectName , Statement } ;
20+ use sqlparser:: ast:: { Expr , Ident , ObjectName , Statement , ObjectNamePart } ;
2121use sqlparser:: dialect:: GenericDialect ;
2222use sqlparser:: parser:: Parser as SqlParser ;
2323use tokio:: sync:: RwLock ;
@@ -75,12 +75,31 @@ impl DfSessionService {
7575 }
7676 }
7777
78+ /// Call this method to register additional UDFs (such as current_schemas)
79+ pub async fn register_udfs ( & self ) -> datafusion:: error:: Result < ( ) > {
80+ let mut ctx = self . session_context . write ( ) . await ;
81+ register_current_schemas_udf ( & mut ctx) ?;
82+ Ok ( ( ) )
83+ }
84+
85+ /// Helper function to read a custom session variable, returning the default if not set.
86+ async fn session_var ( & self , key : & str , default : & str ) -> String {
87+ self . custom_session_vars
88+ . read ( )
89+ . await
90+ . get ( key)
91+ . cloned ( )
92+ . unwrap_or_else ( || default. to_string ( ) )
93+ }
94+
7895 async fn handle_set ( & self , variable : & ObjectName , value : & [ Expr ] ) -> PgWireResult < ( ) > {
96+ // Join all parts of the ObjectName so that "TIME ZONE" becomes "timezone"
7997 let var_name = variable
8098 . 0
81- . first ( )
82- . map ( |ident| ident. to_string ( ) . to_lowercase ( ) )
83- . unwrap_or_default ( ) ;
99+ . iter ( )
100+ . map ( |ident| ident. to_string ( ) )
101+ . collect :: < String > ( )
102+ . to_lowercase ( ) ;
84103
85104 let value_str = match value. first ( ) {
86105 Some ( Expr :: Value ( v) ) => match & v. value {
@@ -151,90 +170,33 @@ impl DfSessionService {
151170 }
152171
153172 async fn handle_show < ' a > ( & self , variable : & [ Ident ] ) -> PgWireResult < QueryResponse < ' a > > {
173+ // Join all identifiers so that "TIME ZONE" becomes "timezone"
154174 let var_name = variable
155- . first ( )
156- . map ( |ident| ident. to_string ( ) . to_lowercase ( ) )
157- . unwrap_or_default ( ) ;
175+ . iter ( )
176+ . map ( |ident| ident. to_string ( ) )
177+ . collect :: < String > ( )
178+ . to_lowercase ( ) ;
158179
159180 let sc_guard = self . session_context . read ( ) . await ;
160181 let config = sc_guard. state ( ) . config ( ) . options ( ) . clone ( ) ;
161182
162183 let value = match var_name. as_str ( ) {
163- "timezone" => config
184+ // Support both "timezone" and "time" so that pgcli/psql are happy.
185+ "timezone" | "time" => config
164186 . execution
165187 . time_zone
166188 . clone ( )
167189 . unwrap_or_else ( || "UTC" . to_string ( ) ) ,
168- "client_encoding" => self
169- . custom_session_vars
170- . read ( )
171- . await
172- . get ( & var_name)
173- . cloned ( )
174- . unwrap_or_else ( || "UTF8" . to_string ( ) ) ,
175- "search_path" => self
176- . custom_session_vars
177- . read ( )
178- . await
179- . get ( & var_name)
180- . cloned ( )
181- . unwrap_or_else ( || "public" . to_string ( ) ) ,
182- "application_name" => self
183- . custom_session_vars
184- . read ( )
185- . await
186- . get ( & var_name)
187- . cloned ( )
188- . unwrap_or_else ( || "" . to_string ( ) ) ,
189- "datestyle" => self
190- . custom_session_vars
191- . read ( )
192- . await
193- . get ( & var_name)
194- . cloned ( )
195- . unwrap_or_else ( || "ISO, MDY" . to_string ( ) ) ,
196- "client_min_messages" => self
197- . custom_session_vars
198- . read ( )
199- . await
200- . get ( & var_name)
201- . cloned ( )
202- . unwrap_or_else ( || "notice" . to_string ( ) ) ,
203- "extra_float_digits" => self
204- . custom_session_vars
205- . read ( )
206- . await
207- . get ( & var_name)
208- . cloned ( )
209- . unwrap_or_else ( || "3" . to_string ( ) ) ,
210- "standard_conforming_strings" => self
211- . custom_session_vars
212- . read ( )
213- . await
214- . get ( & var_name)
215- . cloned ( )
216- . unwrap_or_else ( || "on" . to_string ( ) ) ,
217- "check_function_bodies" => self
218- . custom_session_vars
219- . read ( )
220- . await
221- . get ( & var_name)
222- . cloned ( )
223- . unwrap_or_else ( || "off" . to_string ( ) ) ,
224- "transaction_read_only" => self
225- . custom_session_vars
226- . read ( )
227- . await
228- . get ( & var_name)
229- . cloned ( )
230- . unwrap_or_else ( || "off" . to_string ( ) ) ,
231- "transaction_isolation" => self
232- . custom_session_vars
233- . read ( )
234- . await
235- . get ( & var_name)
236- . cloned ( )
237- . unwrap_or_else ( || "read committed" . to_string ( ) ) ,
190+ "client_encoding" => self . session_var ( "client_encoding" , "UTF8" ) . await ,
191+ "search_path" => self . session_var ( "search_path" , "public" ) . await ,
192+ "application_name" => self . session_var ( "application_name" , "" ) . await ,
193+ "datestyle" => self . session_var ( "datestyle" , "ISO, MDY" ) . await ,
194+ "client_min_messages" => self . session_var ( "client_min_messages" , "notice" ) . await ,
195+ "extra_float_digits" => self . session_var ( "extra_float_digits" , "3" ) . await ,
196+ "standard_conforming_strings" => self . session_var ( "standard_conforming_strings" , "on" ) . await ,
197+ "check_function_bodies" => self . session_var ( "check_function_bodies" , "off" ) . await ,
198+ "transaction_read_only" => self . session_var ( "transaction_read_only" , "off" ) . await ,
199+ "transaction_isolation" => self . session_var ( "transaction_isolation" , "read committed" ) . await ,
238200
239201 // *** New variables to keep psql happy ***
240202 "server_version" => "14.0" . to_string ( ) ,
@@ -280,6 +242,7 @@ impl DfSessionService {
280242 ( "lc_monetary" , "en_US.UTF-8" ) ,
281243 ( "lc_numeric" , "en_US.UTF-8" ) ,
282244 ( "lc_time" , "en_US.UTF-8" ) ,
245+ ( "time" , "UTC" ) ,
283246 ] ;
284247
285248 for ( k, v) in defaults {
@@ -291,7 +254,11 @@ impl DfSessionService {
291254
292255 let schema = Arc :: new ( Schema :: new ( vec ! [
293256 Field :: new( "name" , DataType :: Utf8 , false ) ,
294- Field :: new( "setting" , DataType :: Utf8 , false ) ,
257+ Field :: new(
258+ "setting" ,
259+ DataType :: List ( Box :: new( Field :: new( "item" , DataType :: Utf8 , true ) ) . into( ) ) ,
260+ false ,
261+ ) ,
295262 ] ) ) ;
296263 let batch = RecordBatch :: try_new (
297264 schema. clone ( ) ,
@@ -366,6 +333,55 @@ impl SimpleQueryHandler for DfSessionService {
366333 where
367334 C : ClientInfo + Unpin + Send + Sync ,
368335 {
336+ let query_trimmed = query. trim ( ) ;
337+ let query_lower = query_trimmed. to_lowercase ( ) ;
338+
339+ // Intercept SELECT current_schemas(...) queries.
340+ if query_lower. starts_with ( "select current_schemas(" ) {
341+ // Build a StringArray with "public"
342+ let mut string_builder = StringBuilder :: new ( ) ;
343+ string_builder. append_value ( "public" ) ;
344+ // Build a ListArray containing "public"
345+ let mut list_builder = ListBuilder :: new ( StringBuilder :: new ( ) ) ;
346+ list_builder. values ( ) . append_value ( "public" ) ;
347+ list_builder. append ( true ) ;
348+ let list_array = list_builder. finish ( ) ;
349+
350+ // Define schema for a single column "current_schemas" of type List(Utf8)
351+ let field = Field :: new (
352+ "current_schemas" ,
353+ DataType :: List ( Box :: new ( Field :: new ( "item" , DataType :: Utf8 , true ) ) . into ( ) ) ,
354+ false ,
355+ ) ;
356+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
357+ let batch = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( list_array) ] )
358+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
359+ let sc_guard = self . session_context . read ( ) . await ;
360+ let df = sc_guard
361+ . read_batch ( batch)
362+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
363+ let encoded = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
364+ return Ok ( vec ! [ Response :: Query ( encoded) ] ) ;
365+ }
366+
367+ // Intercept SET TIME ZONE commands to handle them directly.
368+ if query_lower. starts_with ( "set time zone" ) {
369+ let parts: Vec < & str > = query_trimmed. split_whitespace ( ) . collect ( ) ;
370+ if parts. len ( ) >= 4 {
371+ let tz = parts[ 3 ] . trim_matches ( '\'' ) . trim_matches ( '"' ) ;
372+ let object_name =
373+ ObjectName ( vec ! [ ObjectNamePart :: Identifier ( Ident :: new( "timezone" ) ) ] ) ;
374+ let expr = Expr :: Value (
375+ sqlparser:: ast:: Value :: SingleQuotedString ( tz. to_string ( ) ) . into ( ) ,
376+ ) ;
377+ self . handle_set ( & object_name, & [ expr] ) . await ?;
378+ return Ok ( vec ! [ Response :: Execution (
379+ pgwire:: api:: results:: Tag :: new( "SET" ) ,
380+ ) ] ) ;
381+ }
382+ }
383+
384+ // Otherwise, process the query normally.
369385 let dialect = GenericDialect { } ;
370386 let stmts = SqlParser :: parse_sql ( & dialect, query)
371387 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
@@ -382,12 +398,12 @@ impl SimpleQueryHandler for DfSessionService {
382398 } => {
383399 let var = match variables {
384400 sqlparser:: ast:: OneOrManyWithParens :: One ( ref name) => name,
385- sqlparser:: ast:: OneOrManyWithParens :: Many ( ref names) => {
386- names. first ( ) . unwrap ( )
387- }
401+ sqlparser:: ast:: OneOrManyWithParens :: Many ( ref names) => names. first ( ) . unwrap ( ) ,
388402 } ;
389403 self . handle_set ( var, & value) . await ?;
390- responses. push ( Response :: Execution ( pgwire:: api:: results:: Tag :: new ( "SET" ) ) ) ;
404+ responses. push ( Response :: Execution (
405+ pgwire:: api:: results:: Tag :: new ( "SET" ) ,
406+ ) ) ;
391407 }
392408 Statement :: ShowVariable { variable } => {
393409 let resp = self . handle_show ( & variable) . await ?;
@@ -427,7 +443,8 @@ impl ExtendedQueryHandler for DfSessionService {
427443 {
428444 let plan = & target. statement ;
429445 let schema = plan. schema ( ) ;
430- let fields = datatypes:: df_schema_to_pg_fields ( schema. as_ref ( ) , & Format :: UnifiedBinary ) ?;
446+ let fields =
447+ datatypes:: df_schema_to_pg_fields ( schema. as_ref ( ) , & Format :: UnifiedBinary ) ?;
431448 let params = plan
432449 . get_parameter_types ( )
433450 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
@@ -485,7 +502,9 @@ impl ExtendedQueryHandler for DfSessionService {
485502 sqlparser:: ast:: OneOrManyWithParens :: Many ( ref names) => names. first ( ) . unwrap ( ) ,
486503 } ;
487504 self . handle_set ( var, value) . await ?;
488- return Ok ( Response :: Execution ( pgwire:: api:: results:: Tag :: new ( "SET" ) ) ) ;
505+ return Ok ( Response :: Execution (
506+ pgwire:: api:: results:: Tag :: new ( "SET" ) ,
507+ ) ) ;
489508 }
490509 } else if stmt_upper. starts_with ( "SHOW " ) {
491510 let dialect = GenericDialect { } ;
@@ -497,7 +516,7 @@ impl ExtendedQueryHandler for DfSessionService {
497516 }
498517 }
499518
500- // Otherwise, treat it as a normal prepared statement
519+ // Otherwise, treat it as a normal prepared statement.
501520 let plan = & portal. statement . statement ;
502521 let param_types = plan
503522 . get_parameter_types ( )
@@ -515,15 +534,46 @@ impl ExtendedQueryHandler for DfSessionService {
515534 . await
516535 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
517536
518- let resp = datatypes:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
537+ let resp =
538+ datatypes:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
519539 Ok ( Response :: Query ( resp) )
520540 }
521541}
522542
523543fn ordered_param_types ( types : & HashMap < String , Option < DataType > > ) -> Vec < Option < & DataType > > {
524- // Datafusion stores the parameters as a map. In our case, the keys will be
525- // `$1`, `$2` etc. The values will be the parameter types.
544+ // Datafusion stores the parameters as a map. In our case, the keys will be
545+ // `$1`, `$2` etc. The values will be the parameter types.
526546 let mut types_vec = types. iter ( ) . collect :: < Vec < _ > > ( ) ;
527547 types_vec. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
528548 types_vec. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
529549}
550+
551+ /// Register a UDF called `current_schemas` that takes a boolean and returns an array containing "public".
552+ fn register_current_schemas_udf ( ctx : & mut SessionContext ) -> datafusion:: error:: Result < ( ) > {
553+ let current_schemas_fn = Arc :: new ( move |args : & [ ColumnarValue ] | -> datafusion:: error:: Result < ColumnarValue > {
554+ // We ignore the input value; just return a constant list containing "public".
555+ let num_rows = match & args[ 0 ] {
556+ ColumnarValue :: Array ( array) => array. len ( ) ,
557+ ColumnarValue :: Scalar ( _) => 1 ,
558+ } ;
559+ // Build a ListArray containing "public"
560+ let mut list_builder = ListBuilder :: new ( StringBuilder :: new ( ) ) ;
561+ for _ in 0 ..num_rows {
562+ list_builder. values ( ) . append_value ( "public" ) ;
563+ list_builder. append ( true ) ;
564+ }
565+ let list_array = list_builder. finish ( ) ;
566+ Ok ( ColumnarValue :: Array ( Arc :: new ( list_array) ) )
567+ } ) ;
568+
569+ let udf = create_udf (
570+ "current_schemas" ,
571+ vec ! [ DataType :: Boolean ] ,
572+ DataType :: List ( Box :: new ( Field :: new ( "item" , DataType :: Utf8 , true ) ) . into ( ) ) ,
573+ Volatility :: Immutable ,
574+ current_schemas_fn,
575+ ) ;
576+
577+ ctx. register_udf ( udf) ;
578+ Ok ( ( ) )
579+ }
0 commit comments